From 2cf88b433209511e9577805a6f9d543a6ea7dfe2 Mon Sep 17 00:00:00 2001 From: "fangyaozheng@bytedance.com" Date: Thu, 7 Aug 2025 20:46:24 +0800 Subject: [PATCH 1/6] refine knowledgebase and memory logs --- veadk/knowledgebase/knowledgebase.py | 21 +++-- .../knowledgebase_database_adapter.py | 87 +++++++++++-------- veadk/memory/long_term_memory.py | 41 +++++++-- veadk/memory/memory_database_adapter.py | 8 ++ 4 files changed, 107 insertions(+), 50 deletions(-) diff --git a/veadk/knowledgebase/knowledgebase.py b/veadk/knowledgebase/knowledgebase.py index 7d69dfd5..8d382d93 100644 --- a/veadk/knowledgebase/knowledgebase.py +++ b/veadk/knowledgebase/knowledgebase.py @@ -26,10 +26,13 @@ class KnowledgeBase: def __init__( self, backend: Literal["local", "opensearch", "viking", "redis", "mysql"] = "local", - top_k: int = 5, + top_k: int = 10, db_config=None, ): - logger.debug(f"Create knowledgebase, backend is {backend}") + logger.info( + f"Initializing knowledgebase: backend={backend} top_k={top_k} db_config={db_config}" + ) + self.backend = backend self.top_k = top_k @@ -38,6 +41,10 @@ def __init__( database_client=self.db_client ) + logger.info( + f"Initialized knowledgebase: db_client={self.db_client} adapter={self.adapter}" + ) + def add( self, data: str | list[str] | TextIO | BinaryIO | bytes, @@ -51,9 +58,8 @@ def add( for example, if you read the file stream of a pdf, then you need to pass an additional parameter file_ext = '.pdf'. """ kwargs.pop("session_id", None) # remove session_id - self.adapter.add( - data, app_name, user_id="user_id", session_id="session_id", **kwargs - ) + logger.info(f"Adding documents to knowledgebase: app_name={app_name}") + self.adapter.add(data, app_name, **kwargs) def search(self, query: str, app_name: str, top_k: int = None) -> list[str]: """Retrieve documents similar to the query text in the vector database. @@ -66,9 +72,10 @@ def search(self, query: str, app_name: str, top_k: int = None) -> list[str]: """ top_k = self.top_k if top_k is None else top_k - result = self.adapter.query( - query=query, app_name=app_name, user_id="user_id", top_k=top_k + logger.info( + f"Searching knowledgebase: app_name={app_name} query={query} top_k={top_k}" ) + result = self.adapter.query(query=query, app_name=app_name, top_k=top_k) if len(result) == 0: logger.warning(f"No documents found in knowledgebase. Query: {query}") return result diff --git a/veadk/knowledgebase/knowledgebase_database_adapter.py b/veadk/knowledgebase/knowledgebase_database_adapter.py index d81e28e8..bd493556 100644 --- a/veadk/knowledgebase/knowledgebase_database_adapter.py +++ b/veadk/knowledgebase/knowledgebase_database_adapter.py @@ -35,6 +35,14 @@ def format_collection_name(collection_name: str) -> str: return re.sub(r"[^a-z0-9_]", "", replaced_str).lower() +def build_index(**kwargs): + """ + Build the index name for the knowledgebase. + """ + # TODO + ... + + def get_knowledgebase_adapter(backend: str): if backend == DatabaseBackend.REDIS: return KnowledgebaseKVDatabaseAdapter @@ -55,30 +63,22 @@ class KnowledgebaseKVDatabaseAdapter(BaseModel): database_client: BaseDatabase - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): - """Add texts to Redis. - - Key: app_name - Field: app_name:user_id - Value: text in List - """ - # key = f"{app_name}:{user_id}" + def add(self, content: list[str], app_name: str, **kwargs): key = f"{app_name}" + logger.debug(f"Adding documents to Redis database: key={key}") try: for _content in content: self.database_client.add(key, _content) - logger.debug( - f"Successfully added {len(content)} texts to Redis list key `{key}`." - ) + logger.debug(f"Added {len(content)} texts to Redis database: key={key}") except Exception as e: - logger.error(f"Failed to add texts to Redis list key `{key}`: {e}") + logger.error(f"Failed to add texts to Redis database key `{key}`: {e}") raise e - def query(self, query: str, app_name: str, user_id: str, **kwargs): - # key = f"{app_name}:{user_id}" + def query(self, query: str, app_name: str, **kwargs): key = f"{app_name}" top_k = 10 + logger.debug(f"Querying Redis database: key={key} query={query}") try: result = self.database_client.query(key, query) @@ -87,9 +87,8 @@ def query(self, query: str, app_name: str, user_id: str, **kwargs): logger.error(f"Failed to search from Redis list key '{key}': {e}") raise e - def delete(self, app_name: str, user_id: str, session_id: str): + def delete(self, app_name: str, **kwargs): try: - # key = f"{app_name}:{user_id}:{session_id}" key = f"{app_name}" self.database_client.delete(key=key) logger.info(f"Successfully deleted data for app {app_name}") @@ -104,6 +103,8 @@ class KnowledgebaseRelationalDatabaseAdapter(BaseModel): database_client: BaseDatabase def create_table(self, table_name: str): + logger.debug(f"Creating table for SQL database: table_name={table_name}") + sql = f""" CREATE TABLE `{table_name}` ( `id` BIGINT AUTO_INCREMENT PRIMARY KEY, @@ -113,8 +114,11 @@ def create_table(self, table_name: str): """ self.database_client.add(sql) - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): + def add(self, content: list[str], app_name: str, **kwargs): table = app_name + logger.debug( + f"Adding documents to SQL database: table_name={table} content_len={len(content)}" + ) if not self.database_client.table_exists(table): logger.warning(f"Table {table} does not exist, creating...") @@ -126,16 +130,19 @@ def add(self, content: list[str], app_name: str, user_id: str, session_id: str): VALUES (%s); """ self.database_client.add(sql, params=(_content,)) - logger.info(f"Successfully added {len(content)} texts to table {table}.") + logger.debug(f"Added {len(content)} texts to table {table}.") - def query(self, query: str, app_name: str, user_id: str, **kwargs): + def query(self, query: str, app_name: str, **kwargs): """Search content from table app_name.""" table = app_name top_k = 10 + logger.debug( + f"Querying SQL database: table_name={table} query={query} top_k={top_k}" + ) if not self.database_client.table_exists(table): logger.warning( - f"querying {query}, but table `{table}` does not exist, returning empty list." + f"Querying SQL database, but table `{table}` does not exist, returning empty list." ) return [] @@ -145,7 +152,7 @@ def query(self, query: str, app_name: str, user_id: str, **kwargs): results = self.database_client.query(sql) return [item["data"] for item in results] - def delete(self, app_name: str, user_id: str, session_id: str): + def delete(self, app_name: str, **kwargs): table = app_name try: self.database_client.delete(table=table) @@ -160,23 +167,23 @@ class KnowledgebaseVectorDatabaseAdapter(BaseModel): database_client: BaseDatabase - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): - # collection_name = format_collection_name(f"{app_name}_{user_id}") - # knowledgebase is application specific + def add(self, content: list[str], app_name: str, **kwargs): + logger.debug( + f"Adding documents to vector database: app_name={app_name} content_len={len(content)}" + ) collection_name = format_collection_name(f"{app_name}") self.database_client.add(content, collection_name=collection_name) - def query(self, query: str, app_name: str, user_id: str, **kwargs): - # collection_name = format_collection_name(f"{app_name}_{user_id}") - # knowledgebase is application specific + def query(self, query: str, app_name: str, **kwargs): + logger.debug( + f"Querying vector database: collection_name={app_name} query={query}" + ) collection_name = format_collection_name(f"{app_name}") return self.database_client.query( query, collection_name=collection_name, **kwargs ) - def delete(self, app_name: str, user_id: str, session_id: str): - # collection_name = format_collection_name(f"{app_name}_{user_id}") - # knowledgebase is application specific + def delete(self, app_name: str, **kwargs): collection_name = format_collection_name(f"{app_name}") try: self.database_client.delete(collection_name=collection_name) @@ -193,13 +200,13 @@ class KnowledgebaseLocalDatabaseAdapter(BaseModel): database_client: BaseDatabase - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): + def add(self, content: list[str], **kwargs): self.database_client.add(content) def query(self, query: str, app_name: str, user_id: str, **kwargs): return self.database_client.query(query, **kwargs) - def delete(self, app_name: str, user_id: str, session_id: str): + def delete(self, app_name: str, **kwargs): try: self.database_client.delete() logger.info(f"Successfully cleared local database for app {app_name}") @@ -229,24 +236,30 @@ def add( self, content: str | list[str] | TextIO | BinaryIO | bytes, app_name: str, - user_id: str, - session_id: str, **kwargs, ): - # collection_name = format_collection_name(f"{app_name}_{user_id}") collection_name = format_collection_name(f"{app_name}") + logger.debug( + f"Adding documents to Viking database: collection_name={collection_name}" + ) self.get_or_create_collection(collection_name) self.database_client.add(content, collection_name=collection_name, **kwargs) - def query(self, query: str, app_name: str, user_id: str, **kwargs): + def query(self, query: str, app_name: str, **kwargs): collection_name = format_collection_name(f"{app_name}") + logger.debug( + f"Querying Viking database: collection_name={collection_name} query={query}" + ) + + # FIXME(): fix here if not self.database_client.collection_exists(collection_name): raise ValueError(f"Collection {collection_name} does not exist") + return self.database_client.query( query, collection_name=collection_name, **kwargs ) - def delete(self, app_name: str, user_id: str, session_id: str): + def delete(self, app_name: str, **kwargs): # collection_name = format_collection_name(f"{app_name}_{user_id}") collection_name = format_collection_name(f"{app_name}") try: diff --git a/veadk/memory/long_term_memory.py b/veadk/memory/long_term_memory.py index e6e5d13a..af295a16 100644 --- a/veadk/memory/long_term_memory.py +++ b/veadk/memory/long_term_memory.py @@ -25,7 +25,6 @@ from google.genai import types from typing_extensions import override -from veadk.config import getenv from veadk.database import DatabaseFactory from veadk.utils.logger import get_logger @@ -40,20 +39,27 @@ def __init__( backend: Literal[ "local", "opensearch", "redis", "mysql", "viking" ] = "opensearch", - top_k: int = getenv("LONGTERM_MEMORY_TOP_K", 3), + top_k: int = 5, ): if backend == "viking": backend = "viking_mem" self.top_k = top_k self.backend = backend + logger.info( + f"Initializing long term memory: backend={self.backend} top_k={self.top_k}" + ) + self.db_client = DatabaseFactory.create( backend=backend, ) - logger.info(f"Long term memory backend is `{backend}`.") self.adapter = get_memory_adapter(backend)(database_client=self.db_client) + logger.info( + f"Initialized long term memory: db_client={self.db_client} adapter={self.adapter}" + ) + @override async def add_session_to_memory( self, @@ -63,6 +69,8 @@ async def add_session_to_memory( for event in session.events: if not event.content or not event.content.parts: continue + if not event.author == "user": # we only add user event to memory + continue message = event.content.model_dump(exclude_none=True, mode="json") if ( @@ -77,16 +85,30 @@ async def add_session_to_memory( session_id=session.id, ) + logger.info( + f"Added {len(event_list)} events to long term memory: app_name={session.app_name} user_id={session.user_id} session_id={session.id}" + ) + @override async def search_memory(self, *, app_name: str, user_id: str, query: str): + logger.info( + f"Searching long term memory: query={query} app_name={app_name} user_id={user_id}" + ) memory_chunks = self.adapter.query( query=query, app_name=app_name, user_id=user_id, ) if len(memory_chunks) == 0: + logger.info( + f"Found no memory chunks for query: {query} app_name={app_name} user_id={user_id}" + ) return SearchMemoryResponse() + logger.info( + f"Found {len(memory_chunks)} memory chunks for query: {query} app_name={app_name} user_id={user_id}" + ) + memory_events = [] for memory in memory_chunks: try: @@ -94,20 +116,27 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str): try: text = memory_dict["parts"][0]["text"] role = memory_dict["role"] - except KeyError as e: - logger.error( - f"Memory content: {memory_dict}. Error parsing memory: {e}" + except KeyError as _: + # prevent not a standard text-based event + logger.warning( + f"Memory content: {memory_dict}. Skip return this memory." ) continue except json.JSONDecodeError: + # prevent the memory string is not dumped by `event` text = memory role = "user" + memory_events.append( MemoryEntry( author="user", content=types.Content(parts=[types.Part(text=text)], role=role), ) ) + + logger.info( + f"Return {len(memory_events)} memory events for query: {query} app_name={app_name} user_id={user_id}" + ) return SearchMemoryResponse(memories=memory_events) @override diff --git a/veadk/memory/memory_database_adapter.py b/veadk/memory/memory_database_adapter.py index d176377d..f17f0478 100644 --- a/veadk/memory/memory_database_adapter.py +++ b/veadk/memory/memory_database_adapter.py @@ -33,6 +33,14 @@ def format_collection_name(collection_name: str) -> str: return re.sub(r"[^a-z0-9_]", "", replaced_str).lower() +def build_index(**kwargs): + """ + Build the index name for the long-term memory. + """ + # TODO + ... + + def get_memory_adapter(backend: str): if backend == DatabaseBackend.REDIS: return MemoryKVDatabaseAdapter From 5206d7ef1642e51e0e4675ee835252d45e21d956 Mon Sep 17 00:00:00 2001 From: "fangyaozheng@bytedance.com" Date: Thu, 7 Aug 2025 21:03:39 +0800 Subject: [PATCH 2/6] fix index error --- tests/test_knowledgebase.py | 2 -- tests/test_long_term_memory.py | 2 +- veadk/knowledgebase/knowledgebase.py | 2 +- veadk/knowledgebase/knowledgebase_database_adapter.py | 6 +++--- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/test_knowledgebase.py b/tests/test_knowledgebase.py index ec4518ed..c7f33377 100644 --- a/tests/test_knowledgebase.py +++ b/tests/test_knowledgebase.py @@ -33,5 +33,3 @@ async def test_knowledgebase(): ) res = "".join(res_list) assert key in res, f"Test failed for backend local res is {res}" - assert key in res, f"Test failed for backend local res is {res}" - assert key in res, f"Test failed for backend local res is {res}" diff --git a/tests/test_long_term_memory.py b/tests/test_long_term_memory.py index b088a2ce..12532825 100644 --- a/tests/test_long_term_memory.py +++ b/tests/test_long_term_memory.py @@ -49,7 +49,7 @@ async def test_long_term_memory(): events=[ Event( invocation_id="test_invocation_id", - author=agent.name, + author="user", branch=None, content=types.Content( parts=[types.Part(text="My name is Alice.")], diff --git a/veadk/knowledgebase/knowledgebase.py b/veadk/knowledgebase/knowledgebase.py index 8d382d93..7acae933 100644 --- a/veadk/knowledgebase/knowledgebase.py +++ b/veadk/knowledgebase/knowledgebase.py @@ -59,7 +59,7 @@ def add( """ kwargs.pop("session_id", None) # remove session_id logger.info(f"Adding documents to knowledgebase: app_name={app_name}") - self.adapter.add(data, app_name, **kwargs) + self.adapter.add(data, app_name=app_name, **kwargs) def search(self, query: str, app_name: str, top_k: int = None) -> list[str]: """Retrieve documents similar to the query text in the vector database. diff --git a/veadk/knowledgebase/knowledgebase_database_adapter.py b/veadk/knowledgebase/knowledgebase_database_adapter.py index bd493556..b0de2854 100644 --- a/veadk/knowledgebase/knowledgebase_database_adapter.py +++ b/veadk/knowledgebase/knowledgebase_database_adapter.py @@ -200,10 +200,10 @@ class KnowledgebaseLocalDatabaseAdapter(BaseModel): database_client: BaseDatabase - def add(self, content: list[str], **kwargs): - self.database_client.add(content) + def add(self, data: list[str], **kwargs): + self.database_client.add(data) - def query(self, query: str, app_name: str, user_id: str, **kwargs): + def query(self, query: str, **kwargs): return self.database_client.query(query, **kwargs) def delete(self, app_name: str, **kwargs): From ce738622726d9fb4c44b03e6c762afed0dfb302f Mon Sep 17 00:00:00 2001 From: "fangyaozheng@bytedance.com" Date: Fri, 8 Aug 2025 09:42:03 +0800 Subject: [PATCH 3/6] reconstruct database adapter --- veadk/database/database_adapter.py | 248 ++++++++++++++++ veadk/database/viking/viking_database.py | 9 +- veadk/knowledgebase/knowledgebase.py | 30 +- .../knowledgebase_database_adapter.py | 272 ------------------ veadk/memory/long_term_memory.py | 94 +++--- veadk/memory/memory_database_adapter.py | 243 ---------------- 6 files changed, 315 insertions(+), 581 deletions(-) create mode 100644 veadk/database/database_adapter.py delete mode 100644 veadk/knowledgebase/knowledgebase_database_adapter.py delete mode 100644 veadk/memory/memory_database_adapter.py diff --git a/veadk/database/database_adapter.py b/veadk/database/database_adapter.py new file mode 100644 index 00000000..05406212 --- /dev/null +++ b/veadk/database/database_adapter.py @@ -0,0 +1,248 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import BinaryIO, TextIO + +from pydantic import BaseModel, ConfigDict + +from veadk.database.base_database import BaseDatabase +from veadk.database.kv.redis_database import RedisDatabase +from veadk.database.local_database import LocalDataBase +from veadk.database.relational.mysql_database import MysqlDatabase +from veadk.database.vector.opensearch_vector_database import OpenSearchVectorDatabase +from veadk.database.viking.viking_database import VikingDatabase +from veadk.database.viking.viking_memory_db import VikingDatabaseMemory +from veadk.utils.logger import get_logger + +logger = get_logger(__name__) + + +class KVDatabaseAdapter(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + client: RedisDatabase + + def add(self, data: list[str], index: str): + logger.debug(f"Adding documents to Redis database: index={index}") + + try: + for _data in data: + self.client.add(key=index, data=_data) + logger.debug(f"Added {len(data)} texts to Redis database: index={index}") + except Exception as e: + logger.error( + f"Failed to add data to Redis database: index={index} error={e}" + ) + raise e + + def query(self, query: str, index: str, top_k: int = 0) -> list[str]: + logger.debug(f"Querying Redis database: index={index} query={query}") + + # ignore top_k, as KV search only return one result + _ = top_k + + try: + result = self.client.query(key=index, query=query) + return result + except Exception as e: + logger.error(f"Failed to search from Redis: index={index} error={e}") + raise e + + +class RelationalDatabaseAdapter(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + client: MysqlDatabase + + def create_table(self, table_name: str): + logger.debug(f"Creating table for SQL database: table_name={table_name}") + + sql = f""" + CREATE TABLE `{table_name}` ( + `id` BIGINT AUTO_INCREMENT PRIMARY KEY, + `data` TEXT NOT NULL, + `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) ENGINE=InnoDB DEFAULT CHARSET={self.client.config.charset}; + """ + self.client.add(sql) + + def add(self, data: list[str], index: str): + logger.debug( + f"Adding documents to SQL database: table_name={index} data_len={len(data)}" + ) + + if not self.client.table_exists(index): + logger.warning(f"Table {index} does not exist, creating...") + self.create_table(index) + + for _data in data: + sql = f""" + INSERT INTO `{index}` (`data`) + VALUES (%s); + """ + self.client.add(sql, params=(_data,)) + logger.debug(f"Added {len(data)} texts to table {index}.") + + def query(self, query: str, index: str, top_k: int) -> list[str]: + logger.debug( + f"Querying SQL database: table_name={index} query={query} top_k={top_k}" + ) + + if not self.client.table_exists(index): + logger.warning( + f"Querying SQL database, but table `{index}` does not exist, returning empty list." + ) + return [] + + sql = f""" + SELECT `data` FROM `{index}` ORDER BY `created_at` DESC LIMIT {top_k}; + """ + results = self.client.query(sql) + + return [item["data"] for item in results] + + +class VectorDatabaseAdapter(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + client: OpenSearchVectorDatabase + + def _validate_index(self, index: str): + # TODO + pass + + def add(self, data: list[str], index: str): + self._validate_index(index) + + logger.debug( + f"Adding documents to vector database: index={index} data_len={len(data)}" + ) + + self.client.add(data, collection_name=index) + + def query(self, query: str, index: str, top_k: int) -> list[str]: + self._validate_index(index) + + logger.debug( + f"Querying vector database: collection_name={index} query={query} top_k={top_k}" + ) + + return self.client.query( + query=query, + collection_name=index, + top_k=top_k, + ) + + +class VikingDatabaseAdapter(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + client: VikingDatabase + + def _validate_index(self, index: str): + # TODO + pass + + def get_or_create_collection(self, collection_name: str): + if not self.client.collection_exists(collection_name): + self.client.create_collection(collection_name) + + count = 0 + while not self.client.collection_exists(collection_name): + time.sleep(1) + count += 1 + if count > 50: + raise TimeoutError( + f"Collection {collection_name} not created after 50 seconds" + ) + + def add( + self, data: str | list[str] | TextIO | BinaryIO | bytes, index: str, **kwargs + ): + self._validate_index(index) + + logger.debug(f"Adding documents to Viking database: collection_name={index}") + self.get_or_create_collection(index) + self.client.add(data, collection_name=index, **kwargs) + + def query(self, query: str, index: str, top_k: int) -> list[str]: + self._validate_index(index) + + logger.debug(f"Querying Viking database: collection_name={index} query={query}") + + # FIXME(): maybe do not raise, but just return [] + if not self.client.collection_exists(index): + raise ValueError(f"Collection {index} does not exist") + + return self.client.query(query, collection_name=index, top_k=top_k) + + +class VikingDatabaseMemoryAdapter(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + client: VikingDatabaseMemory + + def _validate_index(self, index: str): + # TODO + pass + + def add(self, data: list[str], index: str): + self._validate_index(index) + + logger.debug( + f"Adding documents to Viking database memory: collection_name={index} data_len={len(data)}" + ) + + self.client.add(data, collection_name=index) + + def query(self, query: str, index: str, top_k: int): + self._validate_index(index) + + logger.debug( + f"Querying Viking database memory: collection_name={index} query={query} top_k={top_k}" + ) + + result = self.client.query(query, collection_name=index, top_k=top_k) + return result + + +class LocalDatabaseAdapter(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + client: LocalDataBase + + def add(self, data: list[str], **kwargs): + self.client.add(data) + + def query(self, query: str, **kwargs): + return self.client.query(query, **kwargs) + + +MAPPING = { + RedisDatabase: KVDatabaseAdapter, + MysqlDatabase: RelationalDatabaseAdapter, + LocalDataBase: LocalDatabaseAdapter, + VikingDatabase: VikingDatabaseAdapter, + OpenSearchVectorDatabase: VectorDatabaseAdapter, + VikingDatabaseMemory: VikingDatabaseMemoryAdapter, +} + + +def get_knowledgebase_database_adapter(database_client: BaseDatabase): + return MAPPING[type(database_client)](database_client=database_client) + + +def get_long_term_memory_database_adapter(database_client: BaseDatabase): + return MAPPING[type(database_client)](database_client=database_client) diff --git a/veadk/database/viking/viking_database.py b/veadk/database/viking/viking_database.py index e4654a7b..18a2b831 100644 --- a/veadk/database/viking/viking_database.py +++ b/veadk/database/viking/viking_database.py @@ -215,7 +215,12 @@ def _add_doc(self, collection_name: str, tos_url: str, doc_id: str, **kwargs: An return doc_id - def add(self, data: str | list[str] | TextIO | BinaryIO | bytes, **kwargs: Any): + def add( + self, + data: str | list[str] | TextIO | BinaryIO | bytes, + collection_name: str, + **kwargs, + ): """ Args: data: str, file path or file stream: Both file or file.read() are acceptable. @@ -226,8 +231,6 @@ def add(self, data: str | list[str] | TextIO | BinaryIO | bytes, **kwargs: Any): "doc_id": "", } """ - collection_name = kwargs.get("collection_name") - assert collection_name is not None, "collection_name is required" status, tos_url = self._upload_to_tos(data=data, **kwargs) if status != 200: diff --git a/veadk/knowledgebase/knowledgebase.py b/veadk/knowledgebase/knowledgebase.py index 7acae933..a54001c5 100644 --- a/veadk/knowledgebase/knowledgebase.py +++ b/veadk/knowledgebase/knowledgebase.py @@ -14,14 +14,17 @@ from typing import BinaryIO, Literal, TextIO +from veadk.database.database_adapter import get_knowledgebase_database_adapter from veadk.database.database_factory import DatabaseFactory from veadk.utils.logger import get_logger -from .knowledgebase_database_adapter import get_knowledgebase_adapter - logger = get_logger(__name__) +def build_knowledgebase_index(app_name: str): + return f"{app_name}" + + class KnowledgeBase: def __init__( self, @@ -37,9 +40,7 @@ def __init__( self.top_k = top_k self.db_client = DatabaseFactory.create(backend=backend, config=db_config) - self.adapter = get_knowledgebase_adapter(backend)( - database_client=self.db_client - ) + self.adapter = get_knowledgebase_database_adapter(self.db_client) logger.info( f"Initialized knowledgebase: db_client={self.db_client} adapter={self.adapter}" @@ -53,23 +54,18 @@ def add( ): """ Add documents to the vector database. - You can only upload files or file characters when the adapter type you use is vikingdb. + You can only upload files or file characters when the adapter type used is vikingdb. In addition, if you upload data of the bytes type, for example, if you read the file stream of a pdf, then you need to pass an additional parameter file_ext = '.pdf'. """ - kwargs.pop("session_id", None) # remove session_id - logger.info(f"Adding documents to knowledgebase: app_name={app_name}") - self.adapter.add(data, app_name=app_name, **kwargs) + # TODO: add check for data type + ... - def search(self, query: str, app_name: str, top_k: int = None) -> list[str]: - """Retrieve documents similar to the query text in the vector database. + index = build_knowledgebase_index(app_name) + logger.info(f"Adding documents to knowledgebase: index={index}") + self.adapter.add(data=data, index=index) - Args: - query (str): The query text to be retrieved (e.g., "Who proposed the Turing machine model?") - - Returns: - list[str]: A list of the top most similar document contents retrieved (sorted by vector similarity) - """ + def search(self, query: str, app_name: str, top_k: int = None) -> list[str]: top_k = self.top_k if top_k is None else top_k logger.info( diff --git a/veadk/knowledgebase/knowledgebase_database_adapter.py b/veadk/knowledgebase/knowledgebase_database_adapter.py deleted file mode 100644 index b0de2854..00000000 --- a/veadk/knowledgebase/knowledgebase_database_adapter.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Knowledgebase may use different databases, so we need to create -an adapter to abstract the database operations. -""" - -import re -import time -from typing import BinaryIO, TextIO - -from pydantic import BaseModel, ConfigDict - -from veadk.database.base_database import BaseDatabase -from veadk.database.database_factory import DatabaseBackend -from veadk.utils.logger import get_logger - -logger = get_logger(__name__) - - -def format_collection_name(collection_name: str) -> str: - replaced_str = re.sub(r"[- ]", "_", collection_name) - return re.sub(r"[^a-z0-9_]", "", replaced_str).lower() - - -def build_index(**kwargs): - """ - Build the index name for the knowledgebase. - """ - # TODO - ... - - -def get_knowledgebase_adapter(backend: str): - if backend == DatabaseBackend.REDIS: - return KnowledgebaseKVDatabaseAdapter - elif backend == DatabaseBackend.MYSQL: - return KnowledgebaseRelationalDatabaseAdapter - elif backend == DatabaseBackend.OPENSEARCH: - return KnowledgebaseVectorDatabaseAdapter - elif backend == DatabaseBackend.LOCAL: - return KnowledgebaseLocalDatabaseAdapter - elif backend == DatabaseBackend.VIKING: - return KnowledgebaseVikingDatabaseAdapter - else: - raise ValueError(f"Unknown backend: {backend}") - - -class KnowledgebaseKVDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def add(self, content: list[str], app_name: str, **kwargs): - key = f"{app_name}" - logger.debug(f"Adding documents to Redis database: key={key}") - - try: - for _content in content: - self.database_client.add(key, _content) - logger.debug(f"Added {len(content)} texts to Redis database: key={key}") - except Exception as e: - logger.error(f"Failed to add texts to Redis database key `{key}`: {e}") - raise e - - def query(self, query: str, app_name: str, **kwargs): - key = f"{app_name}" - top_k = 10 - logger.debug(f"Querying Redis database: key={key} query={query}") - - try: - result = self.database_client.query(key, query) - return result[-top_k:] - except Exception as e: - logger.error(f"Failed to search from Redis list key '{key}': {e}") - raise e - - def delete(self, app_name: str, **kwargs): - try: - key = f"{app_name}" - self.database_client.delete(key=key) - logger.info(f"Successfully deleted data for app {app_name}") - except Exception as e: - logger.error(f"Failed to delete data: {e}") - raise e - - -class KnowledgebaseRelationalDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def create_table(self, table_name: str): - logger.debug(f"Creating table for SQL database: table_name={table_name}") - - sql = f""" - CREATE TABLE `{table_name}` ( - `id` BIGINT AUTO_INCREMENT PRIMARY KEY, - `data` TEXT NOT NULL, - `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) ENGINE=InnoDB DEFAULT CHARSET={self.database_client.config.charset}; - """ - self.database_client.add(sql) - - def add(self, content: list[str], app_name: str, **kwargs): - table = app_name - logger.debug( - f"Adding documents to SQL database: table_name={table} content_len={len(content)}" - ) - - if not self.database_client.table_exists(table): - logger.warning(f"Table {table} does not exist, creating...") - self.create_table(table) - - for _content in content: - sql = f""" - INSERT INTO `{table}` (`data`) - VALUES (%s); - """ - self.database_client.add(sql, params=(_content,)) - logger.debug(f"Added {len(content)} texts to table {table}.") - - def query(self, query: str, app_name: str, **kwargs): - """Search content from table app_name.""" - table = app_name - top_k = 10 - logger.debug( - f"Querying SQL database: table_name={table} query={query} top_k={top_k}" - ) - - if not self.database_client.table_exists(table): - logger.warning( - f"Querying SQL database, but table `{table}` does not exist, returning empty list." - ) - return [] - - sql = f""" - SELECT `data` FROM `{table}` ORDER BY `created_at` DESC LIMIT {top_k}; - """ - results = self.database_client.query(sql) - return [item["data"] for item in results] - - def delete(self, app_name: str, **kwargs): - table = app_name - try: - self.database_client.delete(table=table) - logger.info(f"Successfully deleted data from table {app_name}") - except Exception as e: - logger.error(f"Failed to delete data: {e}") - raise e - - -class KnowledgebaseVectorDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def add(self, content: list[str], app_name: str, **kwargs): - logger.debug( - f"Adding documents to vector database: app_name={app_name} content_len={len(content)}" - ) - collection_name = format_collection_name(f"{app_name}") - self.database_client.add(content, collection_name=collection_name) - - def query(self, query: str, app_name: str, **kwargs): - logger.debug( - f"Querying vector database: collection_name={app_name} query={query}" - ) - collection_name = format_collection_name(f"{app_name}") - return self.database_client.query( - query, collection_name=collection_name, **kwargs - ) - - def delete(self, app_name: str, **kwargs): - collection_name = format_collection_name(f"{app_name}") - try: - self.database_client.delete(collection_name=collection_name) - logger.info( - f"Successfully deleted vector database collection for app {app_name}" - ) - except Exception as e: - logger.error(f"Failed to delete vector database collection: {e}") - raise e - - -class KnowledgebaseLocalDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def add(self, data: list[str], **kwargs): - self.database_client.add(data) - - def query(self, query: str, **kwargs): - return self.database_client.query(query, **kwargs) - - def delete(self, app_name: str, **kwargs): - try: - self.database_client.delete() - logger.info(f"Successfully cleared local database for app {app_name}") - except Exception as e: - logger.error(f"Failed to clear local database: {e}") - raise e - - -class KnowledgebaseVikingDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def get_or_create_collection(self, collection_name: str): - if not self.database_client.collection_exists(collection_name): - self.database_client.create_collection(collection_name) - count = 0 - while not self.database_client.collection_exists(collection_name): - time.sleep(1) - count += 1 - if count > 50: - raise TimeoutError( - f"Collection {collection_name} not created after 50 seconds" - ) - - def add( - self, - content: str | list[str] | TextIO | BinaryIO | bytes, - app_name: str, - **kwargs, - ): - collection_name = format_collection_name(f"{app_name}") - logger.debug( - f"Adding documents to Viking database: collection_name={collection_name}" - ) - self.get_or_create_collection(collection_name) - self.database_client.add(content, collection_name=collection_name, **kwargs) - - def query(self, query: str, app_name: str, **kwargs): - collection_name = format_collection_name(f"{app_name}") - logger.debug( - f"Querying Viking database: collection_name={collection_name} query={query}" - ) - - # FIXME(): fix here - if not self.database_client.collection_exists(collection_name): - raise ValueError(f"Collection {collection_name} does not exist") - - return self.database_client.query( - query, collection_name=collection_name, **kwargs - ) - - def delete(self, app_name: str, **kwargs): - # collection_name = format_collection_name(f"{app_name}_{user_id}") - collection_name = format_collection_name(f"{app_name}") - try: - self.database_client.delete(collection_name=collection_name) - logger.info( - f"Successfully deleted vector database collection for app {app_name}" - ) - except Exception as e: - logger.error(f"Failed to delete vector database collection: {e}") - raise e diff --git a/veadk/memory/long_term_memory.py b/veadk/memory/long_term_memory.py index af295a16..204fccf1 100644 --- a/veadk/memory/long_term_memory.py +++ b/veadk/memory/long_term_memory.py @@ -16,6 +16,7 @@ import json from typing import Literal +from google.adk.events.event import Event from google.adk.memory.base_memory_service import ( BaseMemoryService, SearchMemoryResponse, @@ -26,13 +27,16 @@ from typing_extensions import override from veadk.database import DatabaseFactory +from veadk.database.database_adapter import get_long_term_memory_database_adapter from veadk.utils.logger import get_logger -from .memory_database_adapter import get_memory_adapter - logger = get_logger(__name__) +def build_long_term_memory_index(app_name: str, user_id: str): + return f"{app_name}_{user_id}" + + class LongTermMemory(BaseMemoryService): def __init__( self, @@ -53,62 +57,68 @@ def __init__( self.db_client = DatabaseFactory.create( backend=backend, ) - - self.adapter = get_memory_adapter(backend)(database_client=self.db_client) + self.adapter = get_long_term_memory_database_adapter(self.db_client) logger.info( f"Initialized long term memory: db_client={self.db_client} adapter={self.adapter}" ) + def _filter_and_convert_events(self, events: list[Event]) -> list[str]: + final_events = [] + for event in events: + # filter: bad event + if not event.content or not event.content.parts: + continue + + # filter: only add user event to memory to enhance retrieve performance + if not event.author == "user": + continue + + # filter: discard function call and function_response + if not event.content.parts[0].text: + continue + + # convert: to string-format for storage + message = event.content.model_dump(exclude_none=True, mode="json") + final_events.append(json.dumps(message)) + return final_events + @override async def add_session_to_memory( self, session: Session, ): - event_list = [] - for event in session.events: - if not event.content or not event.content.parts: - continue - if not event.author == "user": # we only add user event to memory - continue + event_strings = self._filter_and_convert_events(session.events) + index = build_long_term_memory_index(session.app_name, session.user_id) - message = event.content.model_dump(exclude_none=True, mode="json") - if ( - "text" not in message["parts"][0] - ): # remove function_call & function_resp - continue - event_list.append(json.dumps(message)) - self.adapter.add( - event_list, - app_name=session.app_name, - user_id=session.user_id, - session_id=session.id, + logger.info( + f"Adding {len(event_strings)} events to long term memory: index={index}" ) + self.adapter.add(data=event_strings, index=index) + logger.info( - f"Added {len(event_list)} events to long term memory: app_name={session.app_name} user_id={session.user_id} session_id={session.id}" + f"Added {len(event_strings)} events to long term memory: index={index}" ) @override async def search_memory(self, *, app_name: str, user_id: str, query: str): - logger.info( - f"Searching long term memory: query={query} app_name={app_name} user_id={user_id}" - ) - memory_chunks = self.adapter.query( - query=query, - app_name=app_name, - user_id=user_id, - ) - if len(memory_chunks) == 0: - logger.info( - f"Found no memory chunks for query: {query} app_name={app_name} user_id={user_id}" - ) - return SearchMemoryResponse() + index = build_long_term_memory_index(app_name, user_id) logger.info( - f"Found {len(memory_chunks)} memory chunks for query: {query} app_name={app_name} user_id={user_id}" + f"Searching long term memory: query={query} index={index} top_k={self.top_k}" ) + memory_chunks = self.adapter.query(query=query, index=index, top_k=self.top_k) + + # if len(memory_chunks) == 0: + # logger.info(f"Found no memory chunks for query: {query} index={index}") + # return SearchMemoryResponse() + + # logger.info( + # f"Found {len(memory_chunks)} memory chunks for query: {query} index={index}" + # ) + memory_events = [] for memory in memory_chunks: try: @@ -123,7 +133,7 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str): ) continue except json.JSONDecodeError: - # prevent the memory string is not dumped by `event` + # prevent the memory string is not dumped by `Event` class text = memory role = "user" @@ -135,14 +145,6 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str): ) logger.info( - f"Return {len(memory_events)} memory events for query: {query} app_name={app_name} user_id={user_id}" + f"Return {len(memory_events)} memory events for query: {query} index={index}" ) return SearchMemoryResponse(memories=memory_events) - - @override - async def delete_memory(self, *, app_name: str, user_id: str): - self.adapter.delete( - app_name=app_name, - user_id=user_id, - session_id="", # session_id is not used in the adapter delete method - ) diff --git a/veadk/memory/memory_database_adapter.py b/veadk/memory/memory_database_adapter.py deleted file mode 100644 index f17f0478..00000000 --- a/veadk/memory/memory_database_adapter.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Longterm memory may use different databases, so we need to create -an adapter to abstract the database operations. -""" - -import re - -from pydantic import BaseModel, ConfigDict - -from veadk.database.base_database import BaseDatabase -from veadk.database.database_factory import DatabaseBackend -from veadk.utils.logger import get_logger - -logger = get_logger(__name__) - - -def format_collection_name(collection_name: str) -> str: - replaced_str = re.sub(r"[- ]", "_", collection_name) - return re.sub(r"[^a-z0-9_]", "", replaced_str).lower() - - -def build_index(**kwargs): - """ - Build the index name for the long-term memory. - """ - # TODO - ... - - -def get_memory_adapter(backend: str): - if backend == DatabaseBackend.REDIS: - return MemoryKVDatabaseAdapter - elif backend == DatabaseBackend.MYSQL: - return MemoryRelationalDatabaseAdapter - elif backend == DatabaseBackend.OPENSEARCH: - return MemoryVectorDatabaseAdapter - elif backend == DatabaseBackend.LOCAL: - return MemoryLocalDatabaseAdapter - elif backend == DatabaseBackend.VIKING_MEM: - return MemoryVikingDBAdapter - else: - raise ValueError(f"Unknown backend: {backend}") - - -class MemoryKVDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): - """Add texts to Redis. - - Key: app_name - Field: app_name:user_id - Value: text in List - """ - key = f"{app_name}:{user_id}" - - try: - for _content in content: - self.database_client.add(key, _content) - logger.debug( - f"Successfully added {len(content)} texts to Redis list key `{key}`." - ) - except Exception as e: - logger.error(f"Failed to add texts to Redis list key `{key}`: {e}") - raise e - - def query(self, query: str, app_name: str, user_id: str): - key = f"{app_name}:{user_id}" - top_k = 10 - - try: - result = self.database_client.query(key, query) - # Get latest top_k records. - # The data is stored in a Redis list, and the latest data is at the end of the list. - return result[-top_k:] - except Exception as e: - logger.error(f"Failed to search from Redis list key '{key}': {e}") - raise e - - def delete(self, app_name: str, user_id: str, session_id: str): - try: - self.database_client.delete( - app_name=app_name, user_id=user_id, session_id=session_id - ) - logger.info( - f"Successfully deleted memory data for app {app_name}, user {user_id}, session {session_id}" - ) - except Exception as e: - logger.error(f"Failed to delete memory data: {e}") - raise e - - -class MemoryRelationalDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def create_table(self, table_name: str): - sql = f""" - CREATE TABLE `{table_name}` ( - `id` BIGINT AUTO_INCREMENT PRIMARY KEY, - `data` TEXT NOT NULL, - `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) ENGINE=InnoDB DEFAULT CHARSET={self.database_client.config.charset}; - """ - self.database_client.add(sql) - - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): - table = f"{app_name}_{user_id}" - - if not self.database_client.table_exists(table): - logger.warning(f"Table {table} does not exist, creating...") - self.create_table(table) - - for _content in content: - sql = f""" - INSERT INTO `{table}` (`data`) - VALUES (%s); - """ - self.database_client.add(sql, params=(_content,)) - logger.info(f"Successfully added {len(content)} texts to table {table}.") - - def query(self, query: str, app_name: str, user_id: str): - """Search content from table app_name_user_id.""" - top_k = 10 - table = f"{app_name}_{user_id}" - - if not self.database_client.table_exists(table): - logger.warning( - f"querying {query}, but table `{table}` does not exist, returning empty list." - ) - return [] - - sql = f""" - SELECT `data` FROM `{table}` ORDER BY `created_at` DESC LIMIT {top_k}; - """ - results = self.database_client.query(sql) - return [item["data"] for item in results] - - def delete(self, app_name: str, user_id: str, session_id: str): - table = f"{app_name}_{user_id}" - try: - self.database_client.delete(table=table) - logger.info(f"Successfully deleted memory data from table {table}") - except Exception as e: - logger.error(f"Failed to delete memory data: {e}") - raise e - - -class MemoryVectorDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): - collection_name = format_collection_name(f"{app_name}_{user_id}") - self.database_client.add(content, collection_name=collection_name) - - def query(self, query: str, app_name: str, user_id: str): - collection_name = format_collection_name(f"{app_name}_{user_id}") - top_k = 10 - return self.database_client.query( - query, collection_name=collection_name, top_k=top_k - ) - - def delete(self, app_name: str, user_id: str, session_id: str): - collection_name = format_collection_name(f"{app_name}_{user_id}") - try: - self.database_client.delete(collection_name=collection_name) - logger.info( - f"Successfully deleted vector memory database collection for app {app_name}" - ) - except Exception as e: - logger.error(f"Failed to delete vector memory database collection: {e}") - raise e - - -class MemoryLocalDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): - self.database_client.add(content) - - def query(self, query: str, app_name: str, user_id: str): - return self.database_client.query(query) - - def delete(self, app_name: str, user_id: str, session_id: str): - try: - self.database_client.delete() - logger.info( - f"Successfully cleared local memory database for app {app_name}" - ) - except Exception as e: - logger.error(f"Failed to clear local memory database: {e}") - raise e - - -class MemoryVikingDBAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def add( - self, content: list[str], app_name: str, user_id: str, session_id: str, **kwargs - ): - kwargs.pop("user_id", None) - - collection_name = format_collection_name(f"{app_name}_{user_id}") - self.database_client.add( - content, collection_name=collection_name, user_id=user_id, **kwargs - ) - - def query(self, query: str, app_name: str, user_id: str, **kwargs): - kwargs.pop("user_id", None) - - collection_name = format_collection_name(f"{app_name}_{user_id}") - result = self.database_client.query( - query, collection_name=collection_name, user_id=user_id, **kwargs - ) - return result - - def delete(self, app_name: str, user_id: str, session_id: str): - # collection_name = format_collection_name(f"{app_name}_{user_id}") - # todo: delete viking memory db - ... From f7f25cbb9e19196d6c59279e51b87edf39fb1efd Mon Sep 17 00:00:00 2001 From: "fangyaozheng@bytedance.com" Date: Fri, 8 Aug 2025 10:25:41 +0800 Subject: [PATCH 4/6] update naming --- veadk/database/database_adapter.py | 18 +++++++++++------- veadk/database/viking/viking_memory_db.py | 7 ++++++- veadk/knowledgebase/knowledgebase.py | 13 +++---------- veadk/memory/long_term_memory.py | 5 ++++- 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/veadk/database/database_adapter.py b/veadk/database/database_adapter.py index 05406212..60cbedd9 100644 --- a/veadk/database/database_adapter.py +++ b/veadk/database/database_adapter.py @@ -23,7 +23,7 @@ from veadk.database.relational.mysql_database import MysqlDatabase from veadk.database.vector.opensearch_vector_database import OpenSearchVectorDatabase from veadk.database.viking.viking_database import VikingDatabase -from veadk.database.viking.viking_memory_db import VikingDatabaseMemory +from veadk.database.viking.viking_memory_db import VikingMemoryDatabase from veadk.utils.logger import get_logger logger = get_logger(__name__) @@ -39,7 +39,7 @@ def add(self, data: list[str], index: str): try: for _data in data: - self.client.add(key=index, data=_data) + self.client.add(key=index, value=_data) logger.debug(f"Added {len(data)} texts to Redis database: index={index}") except Exception as e: logger.error( @@ -84,7 +84,7 @@ def add(self, data: list[str], index: str): ) if not self.client.table_exists(index): - logger.warning(f"Table {index} does not exist, creating...") + logger.warning(f"Table {index} does not exist, creating a new table.") self.create_table(index) for _data in data: @@ -133,6 +133,7 @@ def add(self, data: list[str], index: str): self.client.add(data, collection_name=index) def query(self, query: str, index: str, top_k: int) -> list[str]: + # FIXME: confirm self._validate_index(index) logger.debug( @@ -159,6 +160,7 @@ def get_or_create_collection(self, collection_name: str): if not self.client.collection_exists(collection_name): self.client.create_collection(collection_name) + # FIXME count = 0 while not self.client.collection_exists(collection_name): time.sleep(1) @@ -174,6 +176,7 @@ def add( self._validate_index(index) logger.debug(f"Adding documents to Viking database: collection_name={index}") + self.get_or_create_collection(index) self.client.add(data, collection_name=index, **kwargs) @@ -189,22 +192,23 @@ def query(self, query: str, index: str, top_k: int) -> list[str]: return self.client.query(query, collection_name=index, top_k=top_k) -class VikingDatabaseMemoryAdapter(BaseModel): +class VikingMemoryDatabaseAdapter(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - client: VikingDatabaseMemory + client: VikingMemoryDatabase def _validate_index(self, index: str): # TODO pass - def add(self, data: list[str], index: str): + def add(self, data: list[str], index: str, **kwargs): self._validate_index(index) logger.debug( f"Adding documents to Viking database memory: collection_name={index} data_len={len(data)}" ) + # TODO: parse user_id self.client.add(data, collection_name=index) def query(self, query: str, index: str, top_k: int): @@ -236,7 +240,7 @@ def query(self, query: str, **kwargs): LocalDataBase: LocalDatabaseAdapter, VikingDatabase: VikingDatabaseAdapter, OpenSearchVectorDatabase: VectorDatabaseAdapter, - VikingDatabaseMemory: VikingDatabaseMemoryAdapter, + VikingMemoryDatabase: VikingMemoryDatabaseAdapter, } diff --git a/veadk/database/viking/viking_memory_db.py b/veadk/database/viking/viking_memory_db.py index b5c23430..d1ed37ad 100644 --- a/veadk/database/viking/viking_memory_db.py +++ b/veadk/database/viking/viking_memory_db.py @@ -34,6 +34,7 @@ logger = get_logger(__name__) +# FIXME class VikingMemConfig(BaseModel): volcengine_ak: Optional[str] = Field( default=getenv("VOLCENGINE_ACCESS_KEY"), @@ -53,6 +54,7 @@ class VikingMemConfig(BaseModel): ) +# ======= adapted from ... ======= class VikingDBMemoryException(Exception): def __init__(self, code, request_id, message=None): self.code = code @@ -363,7 +365,10 @@ def format_milliseconds(timestamp_ms): return dt.strftime("%Y%m%d %H:%M:%S") -class VikingDatabaseMemory(BaseModel, BaseDatabase): +# ======= adapted from ... ======= + + +class VikingMemoryDatabase(BaseModel, BaseDatabase): config: VikingMemConfig = Field( default_factory=VikingMemConfig, description="VikingDB configuration", diff --git a/veadk/knowledgebase/knowledgebase.py b/veadk/knowledgebase/knowledgebase.py index a54001c5..a95670fa 100644 --- a/veadk/knowledgebase/knowledgebase.py +++ b/veadk/knowledgebase/knowledgebase.py @@ -22,7 +22,7 @@ def build_knowledgebase_index(app_name: str): - return f"{app_name}" + return f"veadk_kb_{app_name}" class KnowledgeBase: @@ -62,7 +62,9 @@ def add( ... index = build_knowledgebase_index(app_name) + logger.info(f"Adding documents to knowledgebase: index={index}") + self.adapter.add(data=data, index=index) def search(self, query: str, app_name: str, top_k: int = None) -> list[str]: @@ -75,12 +77,3 @@ def search(self, query: str, app_name: str, top_k: int = None) -> list[str]: if len(result) == 0: logger.warning(f"No documents found in knowledgebase. Query: {query}") return result - - def delete(self, app_name: str, user_id: str, session_id: str): - """Delete documents in the vector database. - Args: - app_name (str): The name of the application - user_id (str): The user ID - session_id (str): The session ID - """ - self.adapter.delete(app_name=app_name, user_id=user_id, session_id=session_id) diff --git a/veadk/memory/long_term_memory.py b/veadk/memory/long_term_memory.py index 204fccf1..7f63a17d 100644 --- a/veadk/memory/long_term_memory.py +++ b/veadk/memory/long_term_memory.py @@ -74,12 +74,13 @@ def _filter_and_convert_events(self, events: list[Event]) -> list[str]: if not event.author == "user": continue - # filter: discard function call and function_response + # filter: discard function call and function response if not event.content.parts[0].text: continue # convert: to string-format for storage message = event.content.model_dump(exclude_none=True, mode="json") + final_events.append(json.dumps(message)) return final_events @@ -95,6 +96,7 @@ async def add_session_to_memory( f"Adding {len(event_strings)} events to long term memory: index={index}" ) + # check if viking memory database, should give a user id: if/else self.adapter.add(data=event_strings, index=index) logger.info( @@ -109,6 +111,7 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str): f"Searching long term memory: query={query} index={index} top_k={self.top_k}" ) + # user id if viking memory db memory_chunks = self.adapter.query(query=query, index=index, top_k=self.top_k) # if len(memory_chunks) == 0: From 66ffe4fba7919938d4e94f2093646dc95005d1b6 Mon Sep 17 00:00:00 2001 From: "hanzhi.421" Date: Fri, 8 Aug 2025 16:55:22 +0800 Subject: [PATCH 5/6] fix: fix todo --- veadk/database/database_adapter.py | 67 ++++++++++++++++------- veadk/database/database_factory.py | 6 +- veadk/database/viking/viking_database.py | 21 ++++++- veadk/database/viking/viking_memory_db.py | 47 ++++++++-------- veadk/knowledgebase/knowledgebase.py | 11 +++- veadk/memory/long_term_memory.py | 22 ++++---- 6 files changed, 111 insertions(+), 63 deletions(-) diff --git a/veadk/database/database_adapter.py b/veadk/database/database_adapter.py index 60cbedd9..d45d9ca3 100644 --- a/veadk/database/database_adapter.py +++ b/veadk/database/database_adapter.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import re import time from typing import BinaryIO, TextIO @@ -120,8 +120,19 @@ class VectorDatabaseAdapter(BaseModel): client: OpenSearchVectorDatabase def _validate_index(self, index: str): - # TODO - pass + """ + Verify whether the string conforms to the naming rules of index_name in OpenSearch. + https://docs.opensearch.org/2.8/api-reference/index-apis/create-index/ + """ + if not ( + isinstance(index, str) + and not index.startswith(("_", "-")) + and index.islower() + and re.match(r"^[a-z0-9_\-.]+$", index) + ): + raise ValueError( + "The index name does not conform to the naming rules of OpenSearch" + ) def add(self, data: list[str], index: str): self._validate_index(index) @@ -133,9 +144,6 @@ def add(self, data: list[str], index: str): self.client.add(data, collection_name=index) def query(self, query: str, index: str, top_k: int) -> list[str]: - # FIXME: confirm - self._validate_index(index) - logger.debug( f"Querying vector database: collection_name={index} query={query} top_k={top_k}" ) @@ -153,19 +161,34 @@ class VikingDatabaseAdapter(BaseModel): client: VikingDatabase def _validate_index(self, index: str): - # TODO - pass + """ + Only English letters, numbers, and underscores (_) are allowed. + It must start with an English letter and cannot be empty. Length requirement: [1, 128]. + For details, please see: https://www.volcengine.com/docs/84313/1254542?lang=zh + """ + if not ( + isinstance(index, str) + and 0 < len(index) <= 128 + and re.fullmatch(r"^[a-zA-Z][a-zA-Z0-9_]*$", index) + ): + raise ValueError( + "The index name does not conform to the rules: it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128." + ) def get_or_create_collection(self, collection_name: str): if not self.client.collection_exists(collection_name): + logger.warning( + f"Collection {collection_name} does not exist, creating a new collection." + ) self.client.create_collection(collection_name) - # FIXME + # After creation, it is necessary to wait for a while. count = 0 while not self.client.collection_exists(collection_name): + print("here") time.sleep(1) count += 1 - if count > 50: + if count > 60: raise TimeoutError( f"Collection {collection_name} not created after 50 seconds" ) @@ -185,9 +208,8 @@ def query(self, query: str, index: str, top_k: int) -> list[str]: logger.debug(f"Querying Viking database: collection_name={index} query={query}") - # FIXME(): maybe do not raise, but just return [] if not self.client.collection_exists(index): - raise ValueError(f"Collection {index} does not exist") + return [] return self.client.query(query, collection_name=index, top_k=top_k) @@ -198,8 +220,14 @@ class VikingMemoryDatabaseAdapter(BaseModel): client: VikingMemoryDatabase def _validate_index(self, index: str): - # TODO - pass + if not ( + isinstance(index, str) + and 1 <= len(index) <= 128 + and re.fullmatch(r"^[a-zA-Z][a-zA-Z0-9_]*$", index) + ): + raise ValueError( + "The index name does not conform to the rules: it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128." + ) def add(self, data: list[str], index: str, **kwargs): self._validate_index(index) @@ -208,17 +236,16 @@ def add(self, data: list[str], index: str, **kwargs): f"Adding documents to Viking database memory: collection_name={index} data_len={len(data)}" ) - # TODO: parse user_id - self.client.add(data, collection_name=index) + self.client.add(data, collection_name=index, **kwargs) - def query(self, query: str, index: str, top_k: int): + def query(self, query: str, index: str, top_k: int, **kwargs): self._validate_index(index) logger.debug( f"Querying Viking database memory: collection_name={index} query={query} top_k={top_k}" ) - result = self.client.query(query, collection_name=index, top_k=top_k) + result = self.client.query(query, collection_name=index, top_k=top_k, **kwargs) return result @@ -245,8 +272,8 @@ def query(self, query: str, **kwargs): def get_knowledgebase_database_adapter(database_client: BaseDatabase): - return MAPPING[type(database_client)](database_client=database_client) + return MAPPING[type(database_client)](client=database_client) def get_long_term_memory_database_adapter(database_client: BaseDatabase): - return MAPPING[type(database_client)](database_client=database_client) + return MAPPING[type(database_client)](client=database_client) diff --git a/veadk/database/database_factory.py b/veadk/database/database_factory.py index 2dfff2e6..838b008b 100644 --- a/veadk/database/database_factory.py +++ b/veadk/database/database_factory.py @@ -69,12 +69,12 @@ def create(backend: str, config=None) -> BaseDatabase: return VikingDatabase() if config is None else VikingDatabase(config=config) if backend == DatabaseBackend.VIKING_MEM: - from .viking.viking_memory_db import VikingDatabaseMemory + from .viking.viking_memory_db import VikingMemoryDatabase return ( - VikingDatabaseMemory() + VikingMemoryDatabase() if config is None - else VikingDatabaseMemory(config=config) + else VikingMemoryDatabase(config=config) ) else: raise ValueError(f"Unsupported database type: {backend}") diff --git a/veadk/database/viking/viking_database.py b/veadk/database/viking/viking_database.py index 18a2b831..29f6d779 100644 --- a/veadk/database/viking/viking_database.py +++ b/veadk/database/viking/viking_database.py @@ -40,6 +40,7 @@ get_collections_path = "/api/knowledge/collection/info" doc_add_path = "/api/knowledge/doc/add" doc_info_path = "/api/knowledge/doc/info" +doc_del_path = "/api/collection/drop" class VolcengineTOSConfig(BaseModel): @@ -246,9 +247,23 @@ def add( } def delete(self, **kwargs: Any): - # collection_name = kwargs.get("collection_name") - # todo: delete vikingdb - ... + collection_name = kwargs.get("collection_name") + resource_id = kwargs.get("resource_id") + request_param = {"collection_name": collection_name, "resource_id": resource_id} + doc_del_req = prepare_request( + method="POST", path=doc_del_path, config=self.config, data=request_param + ) + rsp = requests.request( + method=doc_del_req.method, + url="http://{}{}".format(g_knowledge_base_domain, doc_del_req.path), + headers=doc_del_req.headers, + data=doc_del_req.body, + ) + result = rsp.json() + if result["code"] != 0: + logger.error(f"Error in add_doc: {result['message']}") + return {"error": result["message"]} + return {} def query(self, query: str, **kwargs: Any) -> list[str]: """ diff --git a/veadk/database/viking/viking_memory_db.py b/veadk/database/viking/viking_memory_db.py index d1ed37ad..bdffc6d0 100644 --- a/veadk/database/viking/viking_memory_db.py +++ b/veadk/database/viking/viking_memory_db.py @@ -34,7 +34,6 @@ logger = get_logger(__name__) -# FIXME class VikingMemConfig(BaseModel): volcengine_ak: Optional[str] = Field( default=getenv("VOLCENGINE_ACCESS_KEY"), @@ -54,8 +53,8 @@ class VikingMemConfig(BaseModel): ) -# ======= adapted from ... ======= -class VikingDBMemoryException(Exception): +# ======= adapted from https://github.com/volcengine/mcp-server/blob/main/server/mcp_server_vikingdb_memory/src/mcp_server_vikingdb_memory/common/memory_client.py ======= +class VikingMemoryException(Exception): def __init__(self, code, request_id, message=None): self.code = code self.request_id = request_id @@ -67,15 +66,15 @@ def __str__(self): return self.message -class VikingDBMemoryService(Service): +class VikingMemoryService(Service): _instance_lock = threading.Lock() def __new__(cls, *args, **kwargs): - if not hasattr(VikingDBMemoryService, "_instance"): - with VikingDBMemoryService._instance_lock: - if not hasattr(VikingDBMemoryService, "_instance"): - VikingDBMemoryService._instance = object.__new__(cls) - return VikingDBMemoryService._instance + if not hasattr(VikingMemoryService, "_instance"): + with VikingMemoryService._instance_lock: + if not hasattr(VikingMemoryService, "_instance"): + VikingMemoryService._instance = object.__new__(cls) + return VikingMemoryService._instance def __init__( self, @@ -88,11 +87,11 @@ def __init__( connection_timeout=30, socket_timeout=30, ): - self.service_info = VikingDBMemoryService.get_service_info( + self.service_info = VikingMemoryService.get_service_info( host, region, scheme, connection_timeout, socket_timeout ) - self.api_info = VikingDBMemoryService.get_api_info() - super(VikingDBMemoryService, self).__init__(self.service_info, self.api_info) + self.api_info = VikingMemoryService.get_api_info() + super(VikingMemoryService, self).__init__(self.service_info, self.api_info) if ak: self.set_ak(ak) if sk: @@ -102,12 +101,12 @@ def __init__( try: self.get_body("Ping", {}, json.dumps({})) except Exception as e: - raise VikingDBMemoryException( + raise VikingMemoryException( 1000028, "missed", "host or region is incorrect: {}".format(str(e)) ) from None def setHeader(self, header): - api_info = VikingDBMemoryService.get_api_info() + api_info = VikingMemoryService.get_api_info() for key in api_info: for item in header: api_info[key].header[item] = header[item] @@ -213,17 +212,17 @@ def get_body_exception(self, api, params, body): try: res_json = json.loads(e.args[0].decode("utf-8")) except Exception: - raise VikingDBMemoryException( + raise VikingMemoryException( 1000028, "missed", "json load res error, res:{}".format(str(e)) ) from None code = res_json.get("code", 1000028) request_id = res_json.get("request_id", 1000028) message = res_json.get("message", None) - raise VikingDBMemoryException(code, request_id, message) + raise VikingMemoryException(code, request_id, message) if res == "": - raise VikingDBMemoryException( + raise VikingMemoryException( 1000028, "missed", "empty response due to unknown error, please contact customer service", @@ -237,15 +236,15 @@ def get_exception(self, api, params): try: res_json = json.loads(e.args[0].decode("utf-8")) except Exception: - raise VikingDBMemoryException( + raise VikingMemoryException( 1000028, "missed", "json load res error, res:{}".format(str(e)) ) from None code = res_json.get("code", 1000028) request_id = res_json.get("request_id", 1000028) message = res_json.get("message", None) - raise VikingDBMemoryException(code, request_id, message) + raise VikingMemoryException(code, request_id, message) if res == "": - raise VikingDBMemoryException( + raise VikingMemoryException( 1000028, "missed", "empty response due to unknown error, please contact customer service", @@ -365,7 +364,7 @@ def format_milliseconds(timestamp_ms): return dt.strftime("%Y%m%d %H:%M:%S") -# ======= adapted from ... ======= +# ======= adapted from https://github.com/volcengine/mcp-server/blob/main/server/mcp_server_vikingdb_memory/src/mcp_server_vikingdb_memory/common/memory_client.py ======= class VikingMemoryDatabase(BaseModel, BaseDatabase): @@ -375,7 +374,7 @@ class VikingMemoryDatabase(BaseModel, BaseDatabase): ) def model_post_init(self, context: Any, /) -> None: - self._vm = VikingDBMemoryService( + self._vm = VikingMemoryService( ak=self.config.volcengine_ak, sk=self.config.volcengine_sk ) @@ -516,8 +515,8 @@ def query(self, query: str, **kwargs: Any) -> list[str]: assert collection_name is not None, "collection_name is required" user_id = kwargs.get("user_id") assert user_id is not None, "user_id is required" - - resp = self.search_memory(collection_name, query, user_id=user_id) + top_k = kwargs.get("top_k", 5) + resp = self.search_memory(collection_name, query, user_id=user_id, top_k=top_k) return resp def delete(self, **kwargs: Any): diff --git a/veadk/knowledgebase/knowledgebase.py b/veadk/knowledgebase/knowledgebase.py index a95670fa..6d4cfdf1 100644 --- a/veadk/knowledgebase/knowledgebase.py +++ b/veadk/knowledgebase/knowledgebase.py @@ -58,8 +58,12 @@ def add( In addition, if you upload data of the bytes type, for example, if you read the file stream of a pdf, then you need to pass an additional parameter file_ext = '.pdf'. """ - # TODO: add check for data type - ... + if self.backend != "viking" and not ( + isinstance(data, str) or isinstance(data, list) + ): + raise ValueError( + "Only vikingdb supports uploading files or file characters." + ) index = build_knowledgebase_index(app_name) @@ -73,7 +77,8 @@ def search(self, query: str, app_name: str, top_k: int = None) -> list[str]: logger.info( f"Searching knowledgebase: app_name={app_name} query={query} top_k={top_k}" ) - result = self.adapter.query(query=query, app_name=app_name, top_k=top_k) + index = build_knowledgebase_index(app_name) + result = self.adapter.query(query=query, index=index, top_k=top_k) if len(result) == 0: logger.warning(f"No documents found in knowledgebase. Query: {query}") return result diff --git a/veadk/memory/long_term_memory.py b/veadk/memory/long_term_memory.py index 7f63a17d..1a5ae0a9 100644 --- a/veadk/memory/long_term_memory.py +++ b/veadk/memory/long_term_memory.py @@ -97,7 +97,10 @@ async def add_session_to_memory( ) # check if viking memory database, should give a user id: if/else - self.adapter.add(data=event_strings, index=index) + if self.backend == "viking_mem": + self.adapter.add(data=event_strings, index=index, user_id=session.user_id) + else: + self.adapter.add(data=event_strings, index=index) logger.info( f"Added {len(event_strings)} events to long term memory: index={index}" @@ -112,15 +115,14 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str): ) # user id if viking memory db - memory_chunks = self.adapter.query(query=query, index=index, top_k=self.top_k) - - # if len(memory_chunks) == 0: - # logger.info(f"Found no memory chunks for query: {query} index={index}") - # return SearchMemoryResponse() - - # logger.info( - # f"Found {len(memory_chunks)} memory chunks for query: {query} index={index}" - # ) + if self.backend == "viking_mem": + memory_chunks = self.adapter.query( + query=query, index=index, top_k=self.top_k, user_id=user_id + ) + else: + memory_chunks = self.adapter.query( + query=query, index=index, top_k=self.top_k + ) memory_events = [] for memory in memory_chunks: From 05f3e403c72e2172d01ad63d45f56a2ac67207af Mon Sep 17 00:00:00 2001 From: "hanzhi.421" Date: Mon, 11 Aug 2025 09:32:06 +0800 Subject: [PATCH 6/6] fix: fix database dpendency issues --- veadk/database/database_adapter.py | 70 +++++++++++++++--------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/veadk/database/database_adapter.py b/veadk/database/database_adapter.py index d45d9ca3..28554787 100644 --- a/veadk/database/database_adapter.py +++ b/veadk/database/database_adapter.py @@ -14,25 +14,18 @@ import re import time from typing import BinaryIO, TextIO - -from pydantic import BaseModel, ConfigDict - from veadk.database.base_database import BaseDatabase -from veadk.database.kv.redis_database import RedisDatabase -from veadk.database.local_database import LocalDataBase -from veadk.database.relational.mysql_database import MysqlDatabase -from veadk.database.vector.opensearch_vector_database import OpenSearchVectorDatabase -from veadk.database.viking.viking_database import VikingDatabase -from veadk.database.viking.viking_memory_db import VikingMemoryDatabase + from veadk.utils.logger import get_logger logger = get_logger(__name__) -class KVDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) +class KVDatabaseAdapter: + def __init__(self, client): + from veadk.database.kv.redis_database import RedisDatabase - client: RedisDatabase + self.client: RedisDatabase = client def add(self, data: list[str], index: str): logger.debug(f"Adding documents to Redis database: index={index}") @@ -61,10 +54,11 @@ def query(self, query: str, index: str, top_k: int = 0) -> list[str]: raise e -class RelationalDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) +class RelationalDatabaseAdapter: + def __init__(self, client): + from veadk.database.relational.mysql_database import MysqlDatabase - client: MysqlDatabase + self.client: MysqlDatabase = client def create_table(self, table_name: str): logger.debug(f"Creating table for SQL database: table_name={table_name}") @@ -114,10 +108,13 @@ def query(self, query: str, index: str, top_k: int) -> list[str]: return [item["data"] for item in results] -class VectorDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) +class VectorDatabaseAdapter: + def __init__(self, client): + from veadk.database.vector.opensearch_vector_database import ( + OpenSearchVectorDatabase, + ) - client: OpenSearchVectorDatabase + self.client: OpenSearchVectorDatabase = client def _validate_index(self, index: str): """ @@ -155,10 +152,11 @@ def query(self, query: str, index: str, top_k: int) -> list[str]: ) -class VikingDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) +class VikingDatabaseAdapter: + def __init__(self, client): + from veadk.database.viking.viking_database import VikingDatabase - client: VikingDatabase + self.client: VikingDatabase = client def _validate_index(self, index: str): """ @@ -214,10 +212,11 @@ def query(self, query: str, index: str, top_k: int) -> list[str]: return self.client.query(query, collection_name=index, top_k=top_k) -class VikingMemoryDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) +class VikingMemoryDatabaseAdapter: + def __init__(self, client): + from veadk.database.viking.viking_memory_db import VikingMemoryDatabase - client: VikingMemoryDatabase + self.client: VikingMemoryDatabase = client def _validate_index(self, index: str): if not ( @@ -249,10 +248,11 @@ def query(self, query: str, index: str, top_k: int, **kwargs): return result -class LocalDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) +class LocalDatabaseAdapter: + def __init__(self, client): + from veadk.database.local_database import LocalDataBase - client: LocalDataBase + self.client: LocalDataBase = client def add(self, data: list[str], **kwargs): self.client.add(data) @@ -262,18 +262,18 @@ def query(self, query: str, **kwargs): MAPPING = { - RedisDatabase: KVDatabaseAdapter, - MysqlDatabase: RelationalDatabaseAdapter, - LocalDataBase: LocalDatabaseAdapter, - VikingDatabase: VikingDatabaseAdapter, - OpenSearchVectorDatabase: VectorDatabaseAdapter, - VikingMemoryDatabase: VikingMemoryDatabaseAdapter, + "RedisDatabase": KVDatabaseAdapter, + "MysqlDatabase": RelationalDatabaseAdapter, + "LocalDataBase": LocalDatabaseAdapter, + "VikingDatabase": VikingDatabaseAdapter, + "OpenSearchVectorDatabase": VectorDatabaseAdapter, + "VikingMemoryDatabase": VikingMemoryDatabaseAdapter, } def get_knowledgebase_database_adapter(database_client: BaseDatabase): - return MAPPING[type(database_client)](client=database_client) + return MAPPING[type(database_client).__name__](client=database_client) def get_long_term_memory_database_adapter(database_client: BaseDatabase): - return MAPPING[type(database_client)](client=database_client) + return MAPPING[type(database_client).__name__](client=database_client)