Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 51 additions & 14 deletions veadk/integrations/ve_viking_db_memory/ve_viking_db_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@
# limitations under the License.

import json
import os
import threading
from veadk.utils.misc import getenv

from volcengine.ApiInfo import ApiInfo
from volcengine.auth.SignerV4 import SignerV4
from volcengine.base.Service import Service
from volcengine.Credentials import Credentials
from volcengine.ServiceInfo import ServiceInfo

from veadk.utils.misc import getenv


class VikingDBMemoryException(Exception):
def __init__(self, code, request_id, message=None):
Expand Down Expand Up @@ -56,7 +59,9 @@ def __init__(
socket_timeout=30,
):
env_host = getenv(
"DATABASE_VIKINGMEM_BASE_URL", default_value=None, allow_false_values=True
"DATABASE_VIKINGMEM_BASE_URL",
default_value=None,
allow_false_values=True,
)
if env_host:
if env_host.startswith("http://"):
Expand Down Expand Up @@ -85,7 +90,9 @@ def __init__(
self.get_body("Ping", {}, json.dumps({}))
except Exception as e:
raise VikingDBMemoryException(
1000028, "missed", "host or region is incorrect: {}".format(str(e))
1000028,
"missed",
"host or region is incorrect: {}".format(str(e)),
) from None

def setHeader(self, header):
Expand Down Expand Up @@ -118,49 +125,70 @@ def get_api_info():
"/api/memory/collection/create",
{},
{},
{"Accept": "application/json", "Content-Type": "application/json"},
{
"Accept": "application/json",
"Content-Type": "application/json",
},
),
"GetCollection": ApiInfo(
"POST",
"/api/memory/collection/info",
{},
{},
{"Accept": "application/json", "Content-Type": "application/json"},
{
"Accept": "application/json",
"Content-Type": "application/json",
},
),
"DropCollection": ApiInfo(
"POST",
"/api/memory/collection/delete",
{},
{},
{"Accept": "application/json", "Content-Type": "application/json"},
{
"Accept": "application/json",
"Content-Type": "application/json",
},
),
"UpdateCollection": ApiInfo(
"POST",
"/api/memory/collection/update",
{},
{},
{"Accept": "application/json", "Content-Type": "application/json"},
{
"Accept": "application/json",
"Content-Type": "application/json",
},
),
"SearchMemory": ApiInfo(
"POST",
"/api/memory/search",
{},
{},
{"Accept": "application/json", "Content-Type": "application/json"},
{
"Accept": "application/json",
"Content-Type": "application/json",
},
),
"AddMessages": ApiInfo(
"POST",
"/api/memory/messages/add",
{},
{},
{"Accept": "application/json", "Content-Type": "application/json"},
{
"Accept": "application/json",
"Content-Type": "application/json",
},
),
"Ping": ApiInfo(
"GET",
"/api/memory/ping",
{},
{},
{"Accept": "application/json", "Content-Type": "application/json"},
{
"Accept": "application/json",
"Content-Type": "application/json",
},
),
}
return api_info
Expand Down Expand Up @@ -199,7 +227,9 @@ def get_body_exception(self, api, params, body):
res_json = json.loads(e.args[0].decode("utf-8"))
except Exception as e:
raise VikingDBMemoryException(
1000028, "missed", "json load res error, res:{}".format(str(e))
1000028,
"missed",
"json load res error, res:{}".format(str(e)),
) from None
code = res_json.get("code", 1000028)
request_id = res_json.get("request_id", 1000028)
Expand All @@ -223,7 +253,9 @@ def get_exception(self, api, params):
res_json = json.loads(e.args[0].decode("utf-8"))
except Exception as e:
raise VikingDBMemoryException(
1000028, "missed", "json load res error, res:{}".format(str(e))
1000028,
"missed",
"json load res error, res:{}".format(str(e)),
) from None
code = res_json.get("code", 1000028)
request_id = res_json.get("request_id", 1000028)
Expand All @@ -241,13 +273,18 @@ def create_collection(
self,
collection_name,
description="",
project="default",
custom_event_type_schemas=[],
custom_entity_type_schemas=[],
builtin_event_types=[],
builtin_entity_types=[],
):
params = {
"CollectionName": collection_name,
"ProjectName": project,
"CollectionType": os.getenv(
"DATABASE_VIKINGMEM_COLLECTION_TYPE", "standard"
),
"Description": description,
"CustomEventTypeSchemas": custom_event_type_schemas,
"CustomEntityTypeSchemas": custom_entity_type_schemas,
Expand All @@ -257,8 +294,8 @@ def create_collection(
res = self.json("CreateCollection", {}, json.dumps(params))
return json.loads(res)

def get_collection(self, collection_name):
params = {"CollectionName": collection_name}
def get_collection(self, collection_name, project="default"):
params = {"CollectionName": collection_name, "ProjectName": project}
res = self.json("GetCollection", {}, json.dumps(params))
return json.loads(res)

Expand Down
27 changes: 20 additions & 7 deletions veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

from pydantic import Field
from typing_extensions import override
from vikingdb import IAM
from vikingdb.memory import VikingMem

import veadk.config # noqa E401
from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
Expand All @@ -30,9 +32,6 @@
from veadk.memory.long_term_memory_backends.base_backend import (
BaseLongTermMemoryBackend,
)
from vikingdb import IAM
from vikingdb.memory import VikingMem

from veadk.utils.logger import get_logger

logger = get_logger(__name__)
Expand All @@ -49,9 +48,16 @@ class VikingDBLTMBackend(BaseLongTermMemoryBackend):

session_token: str = ""

region: str = "cn-beijing"
region: str = Field(
default_factory=lambda: os.getenv("DATABASE_VIKINGMEM_REGION") or "cn-beijing"
)
"""VikingDB memory region"""

volcengine_project: str = Field(
default_factory=lambda: os.getenv("DATABASE_VIKINGMEM_PROJECT") or "default"
)
"""VikingDB memory project"""

memory_type: list[str] = Field(default_factory=list)

def model_post_init(self, __context: Any) -> None:
Expand Down Expand Up @@ -87,7 +93,9 @@ def precheck_index_naming(self):
def _collection_exist(self) -> bool:
try:
client = self._get_client()
client.get_collection(collection_name=self.index)
client.get_collection(
collection_name=self.index, project=self.volcengine_project
)
logger.info(f"Collection {self.index} exist.")
return True
except Exception:
Expand All @@ -101,6 +109,7 @@ def _create_collection(self) -> None:
client = self._get_client()
response = client.create_collection(
collection_name=self.index,
project=self.volcengine_project,
description="Created by Volcengine Agent Development Kit VeADK",
builtin_event_types=self.memory_type,
)
Expand Down Expand Up @@ -156,7 +165,9 @@ def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
)

client = self._get_sdk_client()
collection = client.get_collection(collection_name=self.index)
collection = client.get_collection(
collection_name=self.index, project_name=self.volcengine_project
)
response = collection.add_session(
session_id=session_id,
messages=messages,
Expand All @@ -181,7 +192,9 @@ def search_memory(
)

client = self._get_sdk_client()
collection = client.get_collection(collection_name=self.index)
collection = client.get_collection(
collection_name=self.index, project_name=self.volcengine_project
)
response = collection.search_memory(
query=query,
filter=filter,
Expand Down