diff --git a/veadk/integrations/ve_viking_db_memory/ve_viking_db_memory.py b/veadk/integrations/ve_viking_db_memory/ve_viking_db_memory.py index 04d1f481..3f4e4a49 100644 --- a/veadk/integrations/ve_viking_db_memory/ve_viking_db_memory.py +++ b/veadk/integrations/ve_viking_db_memory/ve_viking_db_memory.py @@ -13,14 +13,17 @@ # limitations under the License. import json +import os import threading -from veadk.utils.misc import getenv + from volcengine.ApiInfo import ApiInfo from volcengine.auth.SignerV4 import SignerV4 from volcengine.base.Service import Service from volcengine.Credentials import Credentials from volcengine.ServiceInfo import ServiceInfo +from veadk.utils.misc import getenv + class VikingDBMemoryException(Exception): def __init__(self, code, request_id, message=None): @@ -56,7 +59,9 @@ def __init__( socket_timeout=30, ): env_host = getenv( - "DATABASE_VIKINGMEM_BASE_URL", default_value=None, allow_false_values=True + "DATABASE_VIKINGMEM_BASE_URL", + default_value=None, + allow_false_values=True, ) if env_host: if env_host.startswith("http://"): @@ -85,7 +90,9 @@ def __init__( self.get_body("Ping", {}, json.dumps({})) except Exception as e: raise VikingDBMemoryException( - 1000028, "missed", "host or region is incorrect: {}".format(str(e)) + 1000028, + "missed", + "host or region is incorrect: {}".format(str(e)), ) from None def setHeader(self, header): @@ -118,49 +125,70 @@ def get_api_info(): "/api/memory/collection/create", {}, {}, - {"Accept": "application/json", "Content-Type": "application/json"}, + { + "Accept": "application/json", + "Content-Type": "application/json", + }, ), "GetCollection": ApiInfo( "POST", "/api/memory/collection/info", {}, {}, - {"Accept": "application/json", "Content-Type": "application/json"}, + { + "Accept": "application/json", + "Content-Type": "application/json", + }, ), "DropCollection": ApiInfo( "POST", "/api/memory/collection/delete", {}, {}, - {"Accept": "application/json", "Content-Type": "application/json"}, + { + "Accept": "application/json", + "Content-Type": "application/json", + }, ), "UpdateCollection": ApiInfo( "POST", "/api/memory/collection/update", {}, {}, - {"Accept": "application/json", "Content-Type": "application/json"}, + { + "Accept": "application/json", + "Content-Type": "application/json", + }, ), "SearchMemory": ApiInfo( "POST", "/api/memory/search", {}, {}, - {"Accept": "application/json", "Content-Type": "application/json"}, + { + "Accept": "application/json", + "Content-Type": "application/json", + }, ), "AddMessages": ApiInfo( "POST", "/api/memory/messages/add", {}, {}, - {"Accept": "application/json", "Content-Type": "application/json"}, + { + "Accept": "application/json", + "Content-Type": "application/json", + }, ), "Ping": ApiInfo( "GET", "/api/memory/ping", {}, {}, - {"Accept": "application/json", "Content-Type": "application/json"}, + { + "Accept": "application/json", + "Content-Type": "application/json", + }, ), } return api_info @@ -199,7 +227,9 @@ def get_body_exception(self, api, params, body): res_json = json.loads(e.args[0].decode("utf-8")) except Exception as e: raise VikingDBMemoryException( - 1000028, "missed", "json load res error, res:{}".format(str(e)) + 1000028, + "missed", + "json load res error, res:{}".format(str(e)), ) from None code = res_json.get("code", 1000028) request_id = res_json.get("request_id", 1000028) @@ -223,7 +253,9 @@ def get_exception(self, api, params): res_json = json.loads(e.args[0].decode("utf-8")) except Exception as e: raise VikingDBMemoryException( - 1000028, "missed", "json load res error, res:{}".format(str(e)) + 1000028, + "missed", + "json load res error, res:{}".format(str(e)), ) from None code = res_json.get("code", 1000028) request_id = res_json.get("request_id", 1000028) @@ -241,6 +273,7 @@ def create_collection( self, collection_name, description="", + project="default", custom_event_type_schemas=[], custom_entity_type_schemas=[], builtin_event_types=[], @@ -248,6 +281,10 @@ def create_collection( ): params = { "CollectionName": collection_name, + "ProjectName": project, + "CollectionType": os.getenv( + "DATABASE_VIKINGMEM_COLLECTION_TYPE", "standard" + ), "Description": description, "CustomEventTypeSchemas": custom_event_type_schemas, "CustomEntityTypeSchemas": custom_entity_type_schemas, @@ -257,8 +294,8 @@ def create_collection( res = self.json("CreateCollection", {}, json.dumps(params)) return json.loads(res) - def get_collection(self, collection_name): - params = {"CollectionName": collection_name} + def get_collection(self, collection_name, project="default"): + params = {"CollectionName": collection_name, "ProjectName": project} res = self.json("GetCollection", {}, json.dumps(params)) return json.loads(res) diff --git a/veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py b/veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py index 89df8088..b5f905da 100644 --- a/veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py +++ b/veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py @@ -21,6 +21,8 @@ from pydantic import Field from typing_extensions import override +from vikingdb import IAM +from vikingdb.memory import VikingMem import veadk.config # noqa E401 from veadk.auth.veauth.utils import get_credential_from_vefaas_iam @@ -30,9 +32,6 @@ from veadk.memory.long_term_memory_backends.base_backend import ( BaseLongTermMemoryBackend, ) -from vikingdb import IAM -from vikingdb.memory import VikingMem - from veadk.utils.logger import get_logger logger = get_logger(__name__) @@ -49,9 +48,16 @@ class VikingDBLTMBackend(BaseLongTermMemoryBackend): session_token: str = "" - region: str = "cn-beijing" + region: str = Field( + default_factory=lambda: os.getenv("DATABASE_VIKINGMEM_REGION") or "cn-beijing" + ) """VikingDB memory region""" + volcengine_project: str = Field( + default_factory=lambda: os.getenv("DATABASE_VIKINGMEM_PROJECT") or "default" + ) + """VikingDB memory project""" + memory_type: list[str] = Field(default_factory=list) def model_post_init(self, __context: Any) -> None: @@ -87,7 +93,9 @@ def precheck_index_naming(self): def _collection_exist(self) -> bool: try: client = self._get_client() - client.get_collection(collection_name=self.index) + client.get_collection( + collection_name=self.index, project=self.volcengine_project + ) logger.info(f"Collection {self.index} exist.") return True except Exception: @@ -101,6 +109,7 @@ def _create_collection(self) -> None: client = self._get_client() response = client.create_collection( collection_name=self.index, + project=self.volcengine_project, description="Created by Volcengine Agent Development Kit VeADK", builtin_event_types=self.memory_type, ) @@ -156,7 +165,9 @@ def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool: ) client = self._get_sdk_client() - collection = client.get_collection(collection_name=self.index) + collection = client.get_collection( + collection_name=self.index, project_name=self.volcengine_project + ) response = collection.add_session( session_id=session_id, messages=messages, @@ -181,7 +192,9 @@ def search_memory( ) client = self._get_sdk_client() - collection = client.get_collection(collection_name=self.index) + collection = client.get_collection( + collection_name=self.index, project_name=self.volcengine_project + ) response = collection.search_memory( query=query, filter=filter,