diff --git a/tests/test_knowledgebase.py b/tests/test_knowledgebase.py index ec4518ed..c7f33377 100644 --- a/tests/test_knowledgebase.py +++ b/tests/test_knowledgebase.py @@ -33,5 +33,3 @@ async def test_knowledgebase(): ) res = "".join(res_list) assert key in res, f"Test failed for backend local res is {res}" - assert key in res, f"Test failed for backend local res is {res}" - assert key in res, f"Test failed for backend local res is {res}" diff --git a/tests/test_long_term_memory.py b/tests/test_long_term_memory.py index b088a2ce..12532825 100644 --- a/tests/test_long_term_memory.py +++ b/tests/test_long_term_memory.py @@ -49,7 +49,7 @@ async def test_long_term_memory(): events=[ Event( invocation_id="test_invocation_id", - author=agent.name, + author="user", branch=None, content=types.Content( parts=[types.Part(text="My name is Alice.")], diff --git a/veadk/database/database_adapter.py b/veadk/database/database_adapter.py new file mode 100644 index 00000000..28554787 --- /dev/null +++ b/veadk/database/database_adapter.py @@ -0,0 +1,279 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +import time +from typing import BinaryIO, TextIO +from veadk.database.base_database import BaseDatabase + +from veadk.utils.logger import get_logger + +logger = get_logger(__name__) + + +class KVDatabaseAdapter: + def __init__(self, client): + from veadk.database.kv.redis_database import RedisDatabase + + self.client: RedisDatabase = client + + def add(self, data: list[str], index: str): + logger.debug(f"Adding documents to Redis database: index={index}") + + try: + for _data in data: + self.client.add(key=index, value=_data) + logger.debug(f"Added {len(data)} texts to Redis database: index={index}") + except Exception as e: + logger.error( + f"Failed to add data to Redis database: index={index} error={e}" + ) + raise e + + def query(self, query: str, index: str, top_k: int = 0) -> list[str]: + logger.debug(f"Querying Redis database: index={index} query={query}") + + # ignore top_k, as KV search only return one result + _ = top_k + + try: + result = self.client.query(key=index, query=query) + return result + except Exception as e: + logger.error(f"Failed to search from Redis: index={index} error={e}") + raise e + + +class RelationalDatabaseAdapter: + def __init__(self, client): + from veadk.database.relational.mysql_database import MysqlDatabase + + self.client: MysqlDatabase = client + + def create_table(self, table_name: str): + logger.debug(f"Creating table for SQL database: table_name={table_name}") + + sql = f""" + CREATE TABLE `{table_name}` ( + `id` BIGINT AUTO_INCREMENT PRIMARY KEY, + `data` TEXT NOT NULL, + `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) ENGINE=InnoDB DEFAULT CHARSET={self.client.config.charset}; + """ + self.client.add(sql) + + def add(self, data: list[str], index: str): + logger.debug( + f"Adding documents to SQL database: table_name={index} data_len={len(data)}" + ) + + if not self.client.table_exists(index): + logger.warning(f"Table {index} does not exist, creating a new table.") + self.create_table(index) + + for _data in data: + sql = f""" + INSERT INTO `{index}` (`data`) + VALUES (%s); + """ + self.client.add(sql, params=(_data,)) + logger.debug(f"Added {len(data)} texts to table {index}.") + + def query(self, query: str, index: str, top_k: int) -> list[str]: + logger.debug( + f"Querying SQL database: table_name={index} query={query} top_k={top_k}" + ) + + if not self.client.table_exists(index): + logger.warning( + f"Querying SQL database, but table `{index}` does not exist, returning empty list." + ) + return [] + + sql = f""" + SELECT `data` FROM `{index}` ORDER BY `created_at` DESC LIMIT {top_k}; + """ + results = self.client.query(sql) + + return [item["data"] for item in results] + + +class VectorDatabaseAdapter: + def __init__(self, client): + from veadk.database.vector.opensearch_vector_database import ( + OpenSearchVectorDatabase, + ) + + self.client: OpenSearchVectorDatabase = client + + def _validate_index(self, index: str): + """ + Verify whether the string conforms to the naming rules of index_name in OpenSearch. + https://docs.opensearch.org/2.8/api-reference/index-apis/create-index/ + """ + if not ( + isinstance(index, str) + and not index.startswith(("_", "-")) + and index.islower() + and re.match(r"^[a-z0-9_\-.]+$", index) + ): + raise ValueError( + "The index name does not conform to the naming rules of OpenSearch" + ) + + def add(self, data: list[str], index: str): + self._validate_index(index) + + logger.debug( + f"Adding documents to vector database: index={index} data_len={len(data)}" + ) + + self.client.add(data, collection_name=index) + + def query(self, query: str, index: str, top_k: int) -> list[str]: + logger.debug( + f"Querying vector database: collection_name={index} query={query} top_k={top_k}" + ) + + return self.client.query( + query=query, + collection_name=index, + top_k=top_k, + ) + + +class VikingDatabaseAdapter: + def __init__(self, client): + from veadk.database.viking.viking_database import VikingDatabase + + self.client: VikingDatabase = client + + def _validate_index(self, index: str): + """ + Only English letters, numbers, and underscores (_) are allowed. + It must start with an English letter and cannot be empty. Length requirement: [1, 128]. + For details, please see: https://www.volcengine.com/docs/84313/1254542?lang=zh + """ + if not ( + isinstance(index, str) + and 0 < len(index) <= 128 + and re.fullmatch(r"^[a-zA-Z][a-zA-Z0-9_]*$", index) + ): + raise ValueError( + "The index name does not conform to the rules: it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128." + ) + + def get_or_create_collection(self, collection_name: str): + if not self.client.collection_exists(collection_name): + logger.warning( + f"Collection {collection_name} does not exist, creating a new collection." + ) + self.client.create_collection(collection_name) + + # After creation, it is necessary to wait for a while. + count = 0 + while not self.client.collection_exists(collection_name): + print("here") + time.sleep(1) + count += 1 + if count > 60: + raise TimeoutError( + f"Collection {collection_name} not created after 50 seconds" + ) + + def add( + self, data: str | list[str] | TextIO | BinaryIO | bytes, index: str, **kwargs + ): + self._validate_index(index) + + logger.debug(f"Adding documents to Viking database: collection_name={index}") + + self.get_or_create_collection(index) + self.client.add(data, collection_name=index, **kwargs) + + def query(self, query: str, index: str, top_k: int) -> list[str]: + self._validate_index(index) + + logger.debug(f"Querying Viking database: collection_name={index} query={query}") + + if not self.client.collection_exists(index): + return [] + + return self.client.query(query, collection_name=index, top_k=top_k) + + +class VikingMemoryDatabaseAdapter: + def __init__(self, client): + from veadk.database.viking.viking_memory_db import VikingMemoryDatabase + + self.client: VikingMemoryDatabase = client + + def _validate_index(self, index: str): + if not ( + isinstance(index, str) + and 1 <= len(index) <= 128 + and re.fullmatch(r"^[a-zA-Z][a-zA-Z0-9_]*$", index) + ): + raise ValueError( + "The index name does not conform to the rules: it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128." + ) + + def add(self, data: list[str], index: str, **kwargs): + self._validate_index(index) + + logger.debug( + f"Adding documents to Viking database memory: collection_name={index} data_len={len(data)}" + ) + + self.client.add(data, collection_name=index, **kwargs) + + def query(self, query: str, index: str, top_k: int, **kwargs): + self._validate_index(index) + + logger.debug( + f"Querying Viking database memory: collection_name={index} query={query} top_k={top_k}" + ) + + result = self.client.query(query, collection_name=index, top_k=top_k, **kwargs) + return result + + +class LocalDatabaseAdapter: + def __init__(self, client): + from veadk.database.local_database import LocalDataBase + + self.client: LocalDataBase = client + + def add(self, data: list[str], **kwargs): + self.client.add(data) + + def query(self, query: str, **kwargs): + return self.client.query(query, **kwargs) + + +MAPPING = { + "RedisDatabase": KVDatabaseAdapter, + "MysqlDatabase": RelationalDatabaseAdapter, + "LocalDataBase": LocalDatabaseAdapter, + "VikingDatabase": VikingDatabaseAdapter, + "OpenSearchVectorDatabase": VectorDatabaseAdapter, + "VikingMemoryDatabase": VikingMemoryDatabaseAdapter, +} + + +def get_knowledgebase_database_adapter(database_client: BaseDatabase): + return MAPPING[type(database_client).__name__](client=database_client) + + +def get_long_term_memory_database_adapter(database_client: BaseDatabase): + return MAPPING[type(database_client).__name__](client=database_client) diff --git a/veadk/database/database_factory.py b/veadk/database/database_factory.py index 2dfff2e6..838b008b 100644 --- a/veadk/database/database_factory.py +++ b/veadk/database/database_factory.py @@ -69,12 +69,12 @@ def create(backend: str, config=None) -> BaseDatabase: return VikingDatabase() if config is None else VikingDatabase(config=config) if backend == DatabaseBackend.VIKING_MEM: - from .viking.viking_memory_db import VikingDatabaseMemory + from .viking.viking_memory_db import VikingMemoryDatabase return ( - VikingDatabaseMemory() + VikingMemoryDatabase() if config is None - else VikingDatabaseMemory(config=config) + else VikingMemoryDatabase(config=config) ) else: raise ValueError(f"Unsupported database type: {backend}") diff --git a/veadk/database/viking/viking_database.py b/veadk/database/viking/viking_database.py index e4654a7b..29f6d779 100644 --- a/veadk/database/viking/viking_database.py +++ b/veadk/database/viking/viking_database.py @@ -40,6 +40,7 @@ get_collections_path = "/api/knowledge/collection/info" doc_add_path = "/api/knowledge/doc/add" doc_info_path = "/api/knowledge/doc/info" +doc_del_path = "/api/collection/drop" class VolcengineTOSConfig(BaseModel): @@ -215,7 +216,12 @@ def _add_doc(self, collection_name: str, tos_url: str, doc_id: str, **kwargs: An return doc_id - def add(self, data: str | list[str] | TextIO | BinaryIO | bytes, **kwargs: Any): + def add( + self, + data: str | list[str] | TextIO | BinaryIO | bytes, + collection_name: str, + **kwargs, + ): """ Args: data: str, file path or file stream: Both file or file.read() are acceptable. @@ -226,8 +232,6 @@ def add(self, data: str | list[str] | TextIO | BinaryIO | bytes, **kwargs: Any): "doc_id": "", } """ - collection_name = kwargs.get("collection_name") - assert collection_name is not None, "collection_name is required" status, tos_url = self._upload_to_tos(data=data, **kwargs) if status != 200: @@ -243,9 +247,23 @@ def add(self, data: str | list[str] | TextIO | BinaryIO | bytes, **kwargs: Any): } def delete(self, **kwargs: Any): - # collection_name = kwargs.get("collection_name") - # todo: delete vikingdb - ... + collection_name = kwargs.get("collection_name") + resource_id = kwargs.get("resource_id") + request_param = {"collection_name": collection_name, "resource_id": resource_id} + doc_del_req = prepare_request( + method="POST", path=doc_del_path, config=self.config, data=request_param + ) + rsp = requests.request( + method=doc_del_req.method, + url="http://{}{}".format(g_knowledge_base_domain, doc_del_req.path), + headers=doc_del_req.headers, + data=doc_del_req.body, + ) + result = rsp.json() + if result["code"] != 0: + logger.error(f"Error in add_doc: {result['message']}") + return {"error": result["message"]} + return {} def query(self, query: str, **kwargs: Any) -> list[str]: """ diff --git a/veadk/database/viking/viking_memory_db.py b/veadk/database/viking/viking_memory_db.py index b5c23430..bdffc6d0 100644 --- a/veadk/database/viking/viking_memory_db.py +++ b/veadk/database/viking/viking_memory_db.py @@ -53,7 +53,8 @@ class VikingMemConfig(BaseModel): ) -class VikingDBMemoryException(Exception): +# ======= adapted from https://github.com/volcengine/mcp-server/blob/main/server/mcp_server_vikingdb_memory/src/mcp_server_vikingdb_memory/common/memory_client.py ======= +class VikingMemoryException(Exception): def __init__(self, code, request_id, message=None): self.code = code self.request_id = request_id @@ -65,15 +66,15 @@ def __str__(self): return self.message -class VikingDBMemoryService(Service): +class VikingMemoryService(Service): _instance_lock = threading.Lock() def __new__(cls, *args, **kwargs): - if not hasattr(VikingDBMemoryService, "_instance"): - with VikingDBMemoryService._instance_lock: - if not hasattr(VikingDBMemoryService, "_instance"): - VikingDBMemoryService._instance = object.__new__(cls) - return VikingDBMemoryService._instance + if not hasattr(VikingMemoryService, "_instance"): + with VikingMemoryService._instance_lock: + if not hasattr(VikingMemoryService, "_instance"): + VikingMemoryService._instance = object.__new__(cls) + return VikingMemoryService._instance def __init__( self, @@ -86,11 +87,11 @@ def __init__( connection_timeout=30, socket_timeout=30, ): - self.service_info = VikingDBMemoryService.get_service_info( + self.service_info = VikingMemoryService.get_service_info( host, region, scheme, connection_timeout, socket_timeout ) - self.api_info = VikingDBMemoryService.get_api_info() - super(VikingDBMemoryService, self).__init__(self.service_info, self.api_info) + self.api_info = VikingMemoryService.get_api_info() + super(VikingMemoryService, self).__init__(self.service_info, self.api_info) if ak: self.set_ak(ak) if sk: @@ -100,12 +101,12 @@ def __init__( try: self.get_body("Ping", {}, json.dumps({})) except Exception as e: - raise VikingDBMemoryException( + raise VikingMemoryException( 1000028, "missed", "host or region is incorrect: {}".format(str(e)) ) from None def setHeader(self, header): - api_info = VikingDBMemoryService.get_api_info() + api_info = VikingMemoryService.get_api_info() for key in api_info: for item in header: api_info[key].header[item] = header[item] @@ -211,17 +212,17 @@ def get_body_exception(self, api, params, body): try: res_json = json.loads(e.args[0].decode("utf-8")) except Exception: - raise VikingDBMemoryException( + raise VikingMemoryException( 1000028, "missed", "json load res error, res:{}".format(str(e)) ) from None code = res_json.get("code", 1000028) request_id = res_json.get("request_id", 1000028) message = res_json.get("message", None) - raise VikingDBMemoryException(code, request_id, message) + raise VikingMemoryException(code, request_id, message) if res == "": - raise VikingDBMemoryException( + raise VikingMemoryException( 1000028, "missed", "empty response due to unknown error, please contact customer service", @@ -235,15 +236,15 @@ def get_exception(self, api, params): try: res_json = json.loads(e.args[0].decode("utf-8")) except Exception: - raise VikingDBMemoryException( + raise VikingMemoryException( 1000028, "missed", "json load res error, res:{}".format(str(e)) ) from None code = res_json.get("code", 1000028) request_id = res_json.get("request_id", 1000028) message = res_json.get("message", None) - raise VikingDBMemoryException(code, request_id, message) + raise VikingMemoryException(code, request_id, message) if res == "": - raise VikingDBMemoryException( + raise VikingMemoryException( 1000028, "missed", "empty response due to unknown error, please contact customer service", @@ -363,14 +364,17 @@ def format_milliseconds(timestamp_ms): return dt.strftime("%Y%m%d %H:%M:%S") -class VikingDatabaseMemory(BaseModel, BaseDatabase): +# ======= adapted from https://github.com/volcengine/mcp-server/blob/main/server/mcp_server_vikingdb_memory/src/mcp_server_vikingdb_memory/common/memory_client.py ======= + + +class VikingMemoryDatabase(BaseModel, BaseDatabase): config: VikingMemConfig = Field( default_factory=VikingMemConfig, description="VikingDB configuration", ) def model_post_init(self, context: Any, /) -> None: - self._vm = VikingDBMemoryService( + self._vm = VikingMemoryService( ak=self.config.volcengine_ak, sk=self.config.volcengine_sk ) @@ -511,8 +515,8 @@ def query(self, query: str, **kwargs: Any) -> list[str]: assert collection_name is not None, "collection_name is required" user_id = kwargs.get("user_id") assert user_id is not None, "user_id is required" - - resp = self.search_memory(collection_name, query, user_id=user_id) + top_k = kwargs.get("top_k", 5) + resp = self.search_memory(collection_name, query, user_id=user_id, top_k=top_k) return resp def delete(self, **kwargs: Any): diff --git a/veadk/knowledgebase/knowledgebase.py b/veadk/knowledgebase/knowledgebase.py index 7d69dfd5..6d4cfdf1 100644 --- a/veadk/knowledgebase/knowledgebase.py +++ b/veadk/knowledgebase/knowledgebase.py @@ -14,28 +14,36 @@ from typing import BinaryIO, Literal, TextIO +from veadk.database.database_adapter import get_knowledgebase_database_adapter from veadk.database.database_factory import DatabaseFactory from veadk.utils.logger import get_logger -from .knowledgebase_database_adapter import get_knowledgebase_adapter - logger = get_logger(__name__) +def build_knowledgebase_index(app_name: str): + return f"veadk_kb_{app_name}" + + class KnowledgeBase: def __init__( self, backend: Literal["local", "opensearch", "viking", "redis", "mysql"] = "local", - top_k: int = 5, + top_k: int = 10, db_config=None, ): - logger.debug(f"Create knowledgebase, backend is {backend}") + logger.info( + f"Initializing knowledgebase: backend={backend} top_k={top_k} db_config={db_config}" + ) + self.backend = backend self.top_k = top_k self.db_client = DatabaseFactory.create(backend=backend, config=db_config) - self.adapter = get_knowledgebase_adapter(backend)( - database_client=self.db_client + self.adapter = get_knowledgebase_database_adapter(self.db_client) + + logger.info( + f"Initialized knowledgebase: db_client={self.db_client} adapter={self.adapter}" ) def add( @@ -46,38 +54,31 @@ def add( ): """ Add documents to the vector database. - You can only upload files or file characters when the adapter type you use is vikingdb. + You can only upload files or file characters when the adapter type used is vikingdb. In addition, if you upload data of the bytes type, for example, if you read the file stream of a pdf, then you need to pass an additional parameter file_ext = '.pdf'. """ - kwargs.pop("session_id", None) # remove session_id - self.adapter.add( - data, app_name, user_id="user_id", session_id="session_id", **kwargs - ) + if self.backend != "viking" and not ( + isinstance(data, str) or isinstance(data, list) + ): + raise ValueError( + "Only vikingdb supports uploading files or file characters." + ) - def search(self, query: str, app_name: str, top_k: int = None) -> list[str]: - """Retrieve documents similar to the query text in the vector database. + index = build_knowledgebase_index(app_name) - Args: - query (str): The query text to be retrieved (e.g., "Who proposed the Turing machine model?") + logger.info(f"Adding documents to knowledgebase: index={index}") - Returns: - list[str]: A list of the top most similar document contents retrieved (sorted by vector similarity) - """ + self.adapter.add(data=data, index=index) + + def search(self, query: str, app_name: str, top_k: int = None) -> list[str]: top_k = self.top_k if top_k is None else top_k - result = self.adapter.query( - query=query, app_name=app_name, user_id="user_id", top_k=top_k + logger.info( + f"Searching knowledgebase: app_name={app_name} query={query} top_k={top_k}" ) + index = build_knowledgebase_index(app_name) + result = self.adapter.query(query=query, index=index, top_k=top_k) if len(result) == 0: logger.warning(f"No documents found in knowledgebase. Query: {query}") return result - - def delete(self, app_name: str, user_id: str, session_id: str): - """Delete documents in the vector database. - Args: - app_name (str): The name of the application - user_id (str): The user ID - session_id (str): The session ID - """ - self.adapter.delete(app_name=app_name, user_id=user_id, session_id=session_id) diff --git a/veadk/knowledgebase/knowledgebase_database_adapter.py b/veadk/knowledgebase/knowledgebase_database_adapter.py deleted file mode 100644 index d81e28e8..00000000 --- a/veadk/knowledgebase/knowledgebase_database_adapter.py +++ /dev/null @@ -1,259 +0,0 @@ -# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Knowledgebase may use different databases, so we need to create -an adapter to abstract the database operations. -""" - -import re -import time -from typing import BinaryIO, TextIO - -from pydantic import BaseModel, ConfigDict - -from veadk.database.base_database import BaseDatabase -from veadk.database.database_factory import DatabaseBackend -from veadk.utils.logger import get_logger - -logger = get_logger(__name__) - - -def format_collection_name(collection_name: str) -> str: - replaced_str = re.sub(r"[- ]", "_", collection_name) - return re.sub(r"[^a-z0-9_]", "", replaced_str).lower() - - -def get_knowledgebase_adapter(backend: str): - if backend == DatabaseBackend.REDIS: - return KnowledgebaseKVDatabaseAdapter - elif backend == DatabaseBackend.MYSQL: - return KnowledgebaseRelationalDatabaseAdapter - elif backend == DatabaseBackend.OPENSEARCH: - return KnowledgebaseVectorDatabaseAdapter - elif backend == DatabaseBackend.LOCAL: - return KnowledgebaseLocalDatabaseAdapter - elif backend == DatabaseBackend.VIKING: - return KnowledgebaseVikingDatabaseAdapter - else: - raise ValueError(f"Unknown backend: {backend}") - - -class KnowledgebaseKVDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): - """Add texts to Redis. - - Key: app_name - Field: app_name:user_id - Value: text in List - """ - # key = f"{app_name}:{user_id}" - key = f"{app_name}" - - try: - for _content in content: - self.database_client.add(key, _content) - logger.debug( - f"Successfully added {len(content)} texts to Redis list key `{key}`." - ) - except Exception as e: - logger.error(f"Failed to add texts to Redis list key `{key}`: {e}") - raise e - - def query(self, query: str, app_name: str, user_id: str, **kwargs): - # key = f"{app_name}:{user_id}" - key = f"{app_name}" - top_k = 10 - - try: - result = self.database_client.query(key, query) - return result[-top_k:] - except Exception as e: - logger.error(f"Failed to search from Redis list key '{key}': {e}") - raise e - - def delete(self, app_name: str, user_id: str, session_id: str): - try: - # key = f"{app_name}:{user_id}:{session_id}" - key = f"{app_name}" - self.database_client.delete(key=key) - logger.info(f"Successfully deleted data for app {app_name}") - except Exception as e: - logger.error(f"Failed to delete data: {e}") - raise e - - -class KnowledgebaseRelationalDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def create_table(self, table_name: str): - sql = f""" - CREATE TABLE `{table_name}` ( - `id` BIGINT AUTO_INCREMENT PRIMARY KEY, - `data` TEXT NOT NULL, - `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) ENGINE=InnoDB DEFAULT CHARSET={self.database_client.config.charset}; - """ - self.database_client.add(sql) - - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): - table = app_name - - if not self.database_client.table_exists(table): - logger.warning(f"Table {table} does not exist, creating...") - self.create_table(table) - - for _content in content: - sql = f""" - INSERT INTO `{table}` (`data`) - VALUES (%s); - """ - self.database_client.add(sql, params=(_content,)) - logger.info(f"Successfully added {len(content)} texts to table {table}.") - - def query(self, query: str, app_name: str, user_id: str, **kwargs): - """Search content from table app_name.""" - table = app_name - top_k = 10 - - if not self.database_client.table_exists(table): - logger.warning( - f"querying {query}, but table `{table}` does not exist, returning empty list." - ) - return [] - - sql = f""" - SELECT `data` FROM `{table}` ORDER BY `created_at` DESC LIMIT {top_k}; - """ - results = self.database_client.query(sql) - return [item["data"] for item in results] - - def delete(self, app_name: str, user_id: str, session_id: str): - table = app_name - try: - self.database_client.delete(table=table) - logger.info(f"Successfully deleted data from table {app_name}") - except Exception as e: - logger.error(f"Failed to delete data: {e}") - raise e - - -class KnowledgebaseVectorDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): - # collection_name = format_collection_name(f"{app_name}_{user_id}") - # knowledgebase is application specific - collection_name = format_collection_name(f"{app_name}") - self.database_client.add(content, collection_name=collection_name) - - def query(self, query: str, app_name: str, user_id: str, **kwargs): - # collection_name = format_collection_name(f"{app_name}_{user_id}") - # knowledgebase is application specific - collection_name = format_collection_name(f"{app_name}") - return self.database_client.query( - query, collection_name=collection_name, **kwargs - ) - - def delete(self, app_name: str, user_id: str, session_id: str): - # collection_name = format_collection_name(f"{app_name}_{user_id}") - # knowledgebase is application specific - collection_name = format_collection_name(f"{app_name}") - try: - self.database_client.delete(collection_name=collection_name) - logger.info( - f"Successfully deleted vector database collection for app {app_name}" - ) - except Exception as e: - logger.error(f"Failed to delete vector database collection: {e}") - raise e - - -class KnowledgebaseLocalDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): - self.database_client.add(content) - - def query(self, query: str, app_name: str, user_id: str, **kwargs): - return self.database_client.query(query, **kwargs) - - def delete(self, app_name: str, user_id: str, session_id: str): - try: - self.database_client.delete() - logger.info(f"Successfully cleared local database for app {app_name}") - except Exception as e: - logger.error(f"Failed to clear local database: {e}") - raise e - - -class KnowledgebaseVikingDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def get_or_create_collection(self, collection_name: str): - if not self.database_client.collection_exists(collection_name): - self.database_client.create_collection(collection_name) - count = 0 - while not self.database_client.collection_exists(collection_name): - time.sleep(1) - count += 1 - if count > 50: - raise TimeoutError( - f"Collection {collection_name} not created after 50 seconds" - ) - - def add( - self, - content: str | list[str] | TextIO | BinaryIO | bytes, - app_name: str, - user_id: str, - session_id: str, - **kwargs, - ): - # collection_name = format_collection_name(f"{app_name}_{user_id}") - collection_name = format_collection_name(f"{app_name}") - self.get_or_create_collection(collection_name) - self.database_client.add(content, collection_name=collection_name, **kwargs) - - def query(self, query: str, app_name: str, user_id: str, **kwargs): - collection_name = format_collection_name(f"{app_name}") - if not self.database_client.collection_exists(collection_name): - raise ValueError(f"Collection {collection_name} does not exist") - return self.database_client.query( - query, collection_name=collection_name, **kwargs - ) - - def delete(self, app_name: str, user_id: str, session_id: str): - # collection_name = format_collection_name(f"{app_name}_{user_id}") - collection_name = format_collection_name(f"{app_name}") - try: - self.database_client.delete(collection_name=collection_name) - logger.info( - f"Successfully deleted vector database collection for app {app_name}" - ) - except Exception as e: - logger.error(f"Failed to delete vector database collection: {e}") - raise e diff --git a/veadk/memory/long_term_memory.py b/veadk/memory/long_term_memory.py index e6e5d13a..1a5ae0a9 100644 --- a/veadk/memory/long_term_memory.py +++ b/veadk/memory/long_term_memory.py @@ -16,6 +16,7 @@ import json from typing import Literal +from google.adk.events.event import Event from google.adk.memory.base_memory_service import ( BaseMemoryService, SearchMemoryResponse, @@ -25,67 +26,103 @@ from google.genai import types from typing_extensions import override -from veadk.config import getenv from veadk.database import DatabaseFactory +from veadk.database.database_adapter import get_long_term_memory_database_adapter from veadk.utils.logger import get_logger -from .memory_database_adapter import get_memory_adapter - logger = get_logger(__name__) +def build_long_term_memory_index(app_name: str, user_id: str): + return f"{app_name}_{user_id}" + + class LongTermMemory(BaseMemoryService): def __init__( self, backend: Literal[ "local", "opensearch", "redis", "mysql", "viking" ] = "opensearch", - top_k: int = getenv("LONGTERM_MEMORY_TOP_K", 3), + top_k: int = 5, ): if backend == "viking": backend = "viking_mem" self.top_k = top_k self.backend = backend + logger.info( + f"Initializing long term memory: backend={self.backend} top_k={self.top_k}" + ) + self.db_client = DatabaseFactory.create( backend=backend, ) - logger.info(f"Long term memory backend is `{backend}`.") + self.adapter = get_long_term_memory_database_adapter(self.db_client) + + logger.info( + f"Initialized long term memory: db_client={self.db_client} adapter={self.adapter}" + ) + + def _filter_and_convert_events(self, events: list[Event]) -> list[str]: + final_events = [] + for event in events: + # filter: bad event + if not event.content or not event.content.parts: + continue + + # filter: only add user event to memory to enhance retrieve performance + if not event.author == "user": + continue + + # filter: discard function call and function response + if not event.content.parts[0].text: + continue + + # convert: to string-format for storage + message = event.content.model_dump(exclude_none=True, mode="json") - self.adapter = get_memory_adapter(backend)(database_client=self.db_client) + final_events.append(json.dumps(message)) + return final_events @override async def add_session_to_memory( self, session: Session, ): - event_list = [] - for event in session.events: - if not event.content or not event.content.parts: - continue + event_strings = self._filter_and_convert_events(session.events) + index = build_long_term_memory_index(session.app_name, session.user_id) - message = event.content.model_dump(exclude_none=True, mode="json") - if ( - "text" not in message["parts"][0] - ): # remove function_call & function_resp - continue - event_list.append(json.dumps(message)) - self.adapter.add( - event_list, - app_name=session.app_name, - user_id=session.user_id, - session_id=session.id, + logger.info( + f"Adding {len(event_strings)} events to long term memory: index={index}" + ) + + # check if viking memory database, should give a user id: if/else + if self.backend == "viking_mem": + self.adapter.add(data=event_strings, index=index, user_id=session.user_id) + else: + self.adapter.add(data=event_strings, index=index) + + logger.info( + f"Added {len(event_strings)} events to long term memory: index={index}" ) @override async def search_memory(self, *, app_name: str, user_id: str, query: str): - memory_chunks = self.adapter.query( - query=query, - app_name=app_name, - user_id=user_id, + index = build_long_term_memory_index(app_name, user_id) + + logger.info( + f"Searching long term memory: query={query} index={index} top_k={self.top_k}" ) - if len(memory_chunks) == 0: - return SearchMemoryResponse() + + # user id if viking memory db + if self.backend == "viking_mem": + memory_chunks = self.adapter.query( + query=query, index=index, top_k=self.top_k, user_id=user_id + ) + else: + memory_chunks = self.adapter.query( + query=query, index=index, top_k=self.top_k + ) memory_events = [] for memory in memory_chunks: @@ -94,26 +131,25 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str): try: text = memory_dict["parts"][0]["text"] role = memory_dict["role"] - except KeyError as e: - logger.error( - f"Memory content: {memory_dict}. Error parsing memory: {e}" + except KeyError as _: + # prevent not a standard text-based event + logger.warning( + f"Memory content: {memory_dict}. Skip return this memory." ) continue except json.JSONDecodeError: + # prevent the memory string is not dumped by `Event` class text = memory role = "user" + memory_events.append( MemoryEntry( author="user", content=types.Content(parts=[types.Part(text=text)], role=role), ) ) - return SearchMemoryResponse(memories=memory_events) - @override - async def delete_memory(self, *, app_name: str, user_id: str): - self.adapter.delete( - app_name=app_name, - user_id=user_id, - session_id="", # session_id is not used in the adapter delete method + logger.info( + f"Return {len(memory_events)} memory events for query: {query} index={index}" ) + return SearchMemoryResponse(memories=memory_events) diff --git a/veadk/memory/memory_database_adapter.py b/veadk/memory/memory_database_adapter.py deleted file mode 100644 index d176377d..00000000 --- a/veadk/memory/memory_database_adapter.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Longterm memory may use different databases, so we need to create -an adapter to abstract the database operations. -""" - -import re - -from pydantic import BaseModel, ConfigDict - -from veadk.database.base_database import BaseDatabase -from veadk.database.database_factory import DatabaseBackend -from veadk.utils.logger import get_logger - -logger = get_logger(__name__) - - -def format_collection_name(collection_name: str) -> str: - replaced_str = re.sub(r"[- ]", "_", collection_name) - return re.sub(r"[^a-z0-9_]", "", replaced_str).lower() - - -def get_memory_adapter(backend: str): - if backend == DatabaseBackend.REDIS: - return MemoryKVDatabaseAdapter - elif backend == DatabaseBackend.MYSQL: - return MemoryRelationalDatabaseAdapter - elif backend == DatabaseBackend.OPENSEARCH: - return MemoryVectorDatabaseAdapter - elif backend == DatabaseBackend.LOCAL: - return MemoryLocalDatabaseAdapter - elif backend == DatabaseBackend.VIKING_MEM: - return MemoryVikingDBAdapter - else: - raise ValueError(f"Unknown backend: {backend}") - - -class MemoryKVDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): - """Add texts to Redis. - - Key: app_name - Field: app_name:user_id - Value: text in List - """ - key = f"{app_name}:{user_id}" - - try: - for _content in content: - self.database_client.add(key, _content) - logger.debug( - f"Successfully added {len(content)} texts to Redis list key `{key}`." - ) - except Exception as e: - logger.error(f"Failed to add texts to Redis list key `{key}`: {e}") - raise e - - def query(self, query: str, app_name: str, user_id: str): - key = f"{app_name}:{user_id}" - top_k = 10 - - try: - result = self.database_client.query(key, query) - # Get latest top_k records. - # The data is stored in a Redis list, and the latest data is at the end of the list. - return result[-top_k:] - except Exception as e: - logger.error(f"Failed to search from Redis list key '{key}': {e}") - raise e - - def delete(self, app_name: str, user_id: str, session_id: str): - try: - self.database_client.delete( - app_name=app_name, user_id=user_id, session_id=session_id - ) - logger.info( - f"Successfully deleted memory data for app {app_name}, user {user_id}, session {session_id}" - ) - except Exception as e: - logger.error(f"Failed to delete memory data: {e}") - raise e - - -class MemoryRelationalDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def create_table(self, table_name: str): - sql = f""" - CREATE TABLE `{table_name}` ( - `id` BIGINT AUTO_INCREMENT PRIMARY KEY, - `data` TEXT NOT NULL, - `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) ENGINE=InnoDB DEFAULT CHARSET={self.database_client.config.charset}; - """ - self.database_client.add(sql) - - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): - table = f"{app_name}_{user_id}" - - if not self.database_client.table_exists(table): - logger.warning(f"Table {table} does not exist, creating...") - self.create_table(table) - - for _content in content: - sql = f""" - INSERT INTO `{table}` (`data`) - VALUES (%s); - """ - self.database_client.add(sql, params=(_content,)) - logger.info(f"Successfully added {len(content)} texts to table {table}.") - - def query(self, query: str, app_name: str, user_id: str): - """Search content from table app_name_user_id.""" - top_k = 10 - table = f"{app_name}_{user_id}" - - if not self.database_client.table_exists(table): - logger.warning( - f"querying {query}, but table `{table}` does not exist, returning empty list." - ) - return [] - - sql = f""" - SELECT `data` FROM `{table}` ORDER BY `created_at` DESC LIMIT {top_k}; - """ - results = self.database_client.query(sql) - return [item["data"] for item in results] - - def delete(self, app_name: str, user_id: str, session_id: str): - table = f"{app_name}_{user_id}" - try: - self.database_client.delete(table=table) - logger.info(f"Successfully deleted memory data from table {table}") - except Exception as e: - logger.error(f"Failed to delete memory data: {e}") - raise e - - -class MemoryVectorDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): - collection_name = format_collection_name(f"{app_name}_{user_id}") - self.database_client.add(content, collection_name=collection_name) - - def query(self, query: str, app_name: str, user_id: str): - collection_name = format_collection_name(f"{app_name}_{user_id}") - top_k = 10 - return self.database_client.query( - query, collection_name=collection_name, top_k=top_k - ) - - def delete(self, app_name: str, user_id: str, session_id: str): - collection_name = format_collection_name(f"{app_name}_{user_id}") - try: - self.database_client.delete(collection_name=collection_name) - logger.info( - f"Successfully deleted vector memory database collection for app {app_name}" - ) - except Exception as e: - logger.error(f"Failed to delete vector memory database collection: {e}") - raise e - - -class MemoryLocalDatabaseAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def add(self, content: list[str], app_name: str, user_id: str, session_id: str): - self.database_client.add(content) - - def query(self, query: str, app_name: str, user_id: str): - return self.database_client.query(query) - - def delete(self, app_name: str, user_id: str, session_id: str): - try: - self.database_client.delete() - logger.info( - f"Successfully cleared local memory database for app {app_name}" - ) - except Exception as e: - logger.error(f"Failed to clear local memory database: {e}") - raise e - - -class MemoryVikingDBAdapter(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - database_client: BaseDatabase - - def add( - self, content: list[str], app_name: str, user_id: str, session_id: str, **kwargs - ): - kwargs.pop("user_id", None) - - collection_name = format_collection_name(f"{app_name}_{user_id}") - self.database_client.add( - content, collection_name=collection_name, user_id=user_id, **kwargs - ) - - def query(self, query: str, app_name: str, user_id: str, **kwargs): - kwargs.pop("user_id", None) - - collection_name = format_collection_name(f"{app_name}_{user_id}") - result = self.database_client.query( - query, collection_name=collection_name, user_id=user_id, **kwargs - ) - return result - - def delete(self, app_name: str, user_id: str, session_id: str): - # collection_name = format_collection_name(f"{app_name}_{user_id}") - # todo: delete viking memory db - ...