Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions vectordb_bench/backend/clients/envector/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ class EnVectorTypedDict(TypedDict):
str,
click.option("--eval-mode", help="Evaluation mode", type=click.Choice(["mm", "rmp"]), default="mm"),
]


class EnVectorFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict): ...


@cli.command(name="envectorflat")
@click_parameter_decorators_from_typed_dict(EnVectorFlatIndexTypedDict)
def EnVectorFlat(**parameters: Unpack[EnVectorFlatIndexTypedDict]):
from .config import FlatIndexConfig, EnVectorConfig
from .config import EnVectorConfig, FlatIndexConfig

run(
db=DBTYPE,
Expand All @@ -46,7 +46,7 @@ def EnVectorFlat(**parameters: Unpack[EnVectorFlatIndexTypedDict]):
)


class EnVectorIVFFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict):
class EnVectorIVFFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict):
nlist: Annotated[
int,
click.option("--nlist", type=int, help="nlist for IVF index", default=250),
Expand Down Expand Up @@ -76,7 +76,7 @@ class EnVectorIVFFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict):
@cli.command(name="envectorivfflat")
@click_parameter_decorators_from_typed_dict(EnVectorIVFFlatIndexTypedDict)
def EnVectorIVFFlat(**parameters: Unpack[EnVectorIVFFlatIndexTypedDict]):
from .config import IVFFlatIndexConfig, EnVectorConfig
from .config import EnVectorConfig, IVFFlatIndexConfig

run(
db=DBTYPE,
Expand All @@ -87,7 +87,7 @@ def EnVectorIVFFlat(**parameters: Unpack[EnVectorIVFFlatIndexTypedDict]):
index_params={"nlist": parameters["nlist"], "nprobe": parameters["nprobe"]},
),
db_case_config=IVFFlatIndexConfig(
nlist=parameters["nlist"],
nlist=parameters["nlist"],
nprobe=parameters["nprobe"],
train_centroids=parameters["train_centroids"],
centroids_path=parameters["centroids_path"],
Expand Down
8 changes: 4 additions & 4 deletions vectordb_bench/backend/clients/envector/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydantic import BaseModel, SecretStr

from ..api import DBCaseConfig, DBConfig, IndexType, MetricType, SQType
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType


class EnVectorConfig(DBConfig):
Expand Down Expand Up @@ -67,10 +67,10 @@ class IVFFlatIndexConfig(EnVectorIndexConfig, DBCaseConfig):
nlist: int = 0
nprobe: int = 0
eval_mode: str = "mm"
train_centroids: bool = False # whether to train centroids before inserting data
train_centroids: bool = False # whether to train centroids before inserting data
centroids_path: str | None = None # path to centroids file
is_vct: bool = False # whether use VCT index
vct_path: str | None = None # path to VCT index file
is_vct: bool = False # whether use VCT index
vct_path: str | None = None # path to VCT index file

def index_param(self) -> dict:
return {
Expand Down
49 changes: 22 additions & 27 deletions vectordb_bench/backend/clients/envector/envector.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
"""Wrapper around the EnVector vector database over VectorDB"""

from typing import Any, Dict

import logging
import os
from collections.abc import Iterable
from contextlib import contextmanager
import pickle

import numpy as np
from typing import Any

import es2
import numpy as np

from vectordb_bench.backend.filter import Filter, FilterOp

from ..api import VectorDB
from .config import EnVectorIndexConfig


log = logging.getLogger(__name__)


Expand Down Expand Up @@ -45,8 +41,8 @@ def __init__(
self.case_config = db_case_config
self.collection_name = collection_name

self.batch_size = 128 * 32 # default batch size for insertions, can be modified for IVF_FLAT
self.batch_size = 128 * 32 # default batch size for insertions, can be modified for IVF_FLAT

self._primary_field = "pk"
self._scalar_id_field = "id"
self._scalar_label_field = "label"
Expand All @@ -57,23 +53,23 @@ def __init__(
self.col: es2.Index | None = None

self.is_vct: bool = False
self.vct_params: Dict[str, Any] = {}
kwargs: Dict[str, Any] = {}
self.vct_params: dict[str, Any] = {}
kwargs: dict[str, Any] = {}

es2.init(
address=self.db_config.get("uri"),
key_path=self.db_config.get("key_path"),
address=self.db_config.get("uri"),
key_path=self.db_config.get("key_path"),
key_id=self.db_config.get("key_id"),
eval_mode=self.case_config.eval_mode,
)
if drop_old:
log.info(f"{self.name} client drop_old index: {self.collection_name}")
if self.collection_name in es2.get_index_list():
log.info(f"{self.name} client drop_old index: {self.collection_name}")
if self.collection_name in es2.get_index_list():
es2.drop_index(self.collection_name)

# Create the collection
log.info(f"{self.name} create index: {self.collection_name}")

if self.collection_name in es2.get_index_list():
log.info(f"{self.name} index {self.collection_name} already exists, skip creating")
self.is_vct = self.case_config.index_param().get("is_vct", False)
Expand All @@ -83,21 +79,21 @@ def __init__(
index_param = self.case_config.index_param().get("params", {})
index_type = index_param.get("index_type", "FLAT")
train_centroids = self.case_config.index_param().get("train_centroids", False)

if index_type == "IVF_FLAT" and train_centroids:

centroid_path = self.case_config.index_param().get("centroids_path", None)
self.is_vct = self.case_config.index_param().get("is_vct", False)
log.debug(f"IS_VCT: {self.is_vct}")

if centroid_path is not None:
if not os.path.exists(centroid_path):
raise FileNotFoundError(f"Centroid file {centroid_path} not found for IVF_FLAT index training.")

# load trained centroids from file
log.debug(f"Centroids: {centroid_path}")
centroids = np.load(centroid_path)
log.info(f"{self.name} loaded centroids from {centroid_path} for IVF_FLAT index training.")
log.info(f"{self.name} loaded centroids from {centroid_path} for IVF_FLAT index training.")

# set centroids for index creation
index_param["centroids"] = centroids.tolist()
Expand Down Expand Up @@ -190,7 +186,7 @@ def insert_embeddings(
# use the first insert_embeddings to init collection
assert self.col is not None
assert len(embeddings) == len(metadata)

log.debug(f"IS_VCT: {self.is_vct}")

insert_count = 0
Expand Down Expand Up @@ -229,7 +225,7 @@ def search_embedding(
output_fields=["metadata"],
search_params=self.case_config.search_param().get("search_params", {}),
)

else:
# Perform the search.
res = self.col.search(
Expand All @@ -249,9 +245,8 @@ def search_embedding(
log.debug(f"Search results: {res[0][:1]}") # Log first 1 results for debugging
if len(res) > 0 and len(res[0]) > 0:
return [int(result["metadata"]) for result in res[0] if "metadata" in result]
else:
log.warning(f"Unexpected result structure: {res}")
return []
log.warning(f"Unexpected result structure: {res}")
return []
Comment thread
SongHyeopPark marked this conversation as resolved.
Outdated

except Exception as e:
log.error(f"Search failed: {e}")
Expand Down
4 changes: 2 additions & 2 deletions vectordb_bench/log_util.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging
import os
from logging import config
from pathlib import Path
import os


def init(log_level: str):
os.environ["TQDM_DISABLE"] = "1"

# Create logs directory if it doesn't exist
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
Expand Down
Loading