diff --git a/.env.example b/.env.example index 618cac556..f00ceed24 100644 --- a/.env.example +++ b/.env.example @@ -3,9 +3,9 @@ # LOG_NAME= # TIMEZONE= -# NUM_PER_BATCH= +NUM_PER_BATCH=4096 # DEFAULT_DATASET_URL= -DATASET_LOCAL_DIR="/tmp/vectordb_bench/dataset" +DATASET_LOCAL_DIR="/data/vectordb_bench/dataset" # DROP_OLD = True diff --git a/README.md b/README.md index 9ee685770..7c1e9f46d 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,11 @@ +# enVector with ANN (GAS) in VectorDBBench + +The guide on how to use enVector with an ANN index in VectorDBBench is available in [README_ENVECTOR.md](README_ENVECTOR.md). + +The following is the original content of the VectorDBBench README: + +--- + # VectorDBBench(VDBBench): A Benchmark Tool for VectorDB [![version](https://img.shields.io/pypi/v/vectordb-bench.svg?color=blue)](https://pypi.org/project/vectordb-bench/) @@ -422,6 +430,9 @@ python -m vectordb_bench OR: +If you are using a [dev container](https://code.visualstudio.com/docs/devcontainers/containers), create +the following dataset directory first: + ```shell init_bench ``` diff --git a/README_ENVECTOR.md b/README_ENVECTOR.md new file mode 100644 index 000000000..68bd785de --- /dev/null +++ b/README_ENVECTOR.md @@ -0,0 +1,132 @@ +# enVector with ANN (GAS) in VectorDBBench + +This guide demonstrates how to use enVector with an ANN index in VectorDBBench. + +Basic usage of enVector with VectorDBBench follows the standard procedure for [VectorDBBench](https://github.com/zilliztech/VectorDBBench). + +## Structure + +```bash +. 
+├── centroids +│ └── embeddinggemma-300m +│ ├── centroids.npy # centroids file for ANN +│ └── tree_info.pkl # tree metadata for ANN +├── dataset +│ └── pubmed768d400k # VectorDB ANN benchmark dataset +│ ├── neighbors.parquet +│ ├── test.npy +│ └── train.pkl +├── README_ENVECTOR.md +└── scripts + ├── run_benchmark.sh # benchmark script + ├── envector_pubmed_config.yml # benchmark config file + └── prepare_dataset.py # download and prepare ground truth neighbors for dataset +``` + +## Prerequisites + +### Install Python Dependencies +```bash +# 1. Create your environment +python -m venv .venv +source .venv/bin/activate + +# 2. Install VectorDBBench +pip install -e . + +# 3. Install es2 +pip install es2==1.2.0a4 +``` + +### Prepare dataset + +Prepare the following artifacts for the ANN benchmark with `scripts/prepare_dataset.py`: + +- download datasets from HuggingFace +- prepare ground-truth neighbors +- download centroids and tree metadata for the GAS index corresponding to the embedding model + +For the ANN benchmark, we provide two datasets via HuggingFace: +- PUBMED768D400K: [cryptolab-playground/pubmed-arxiv-abstract-embedding-gemma-300m](https://huggingface.co/datasets/cryptolab-playground/pubmed-arxiv-abstract-embedding-gemma-300m) +- BLOOMBERG768D368K: [cryptolab-playground/Bloomberg-Financial-News-embedding-gemma-300m](https://huggingface.co/datasets/cryptolab-playground/Bloomberg-Financial-News-embedding-gemma-300m) + +Also, we provide centroids and tree metadata for the corresponding embedding model used in the ANN benchmark: +- GAS Centroids: [cryptolab-playground/gas-centroids](https://huggingface.co/datasets/cryptolab-playground/gas-centroids) + +To prepare the dataset, run the following command as an example: + +```bash +# Prepare dataset +python ./scripts/prepare_dataset.py \ + -d cryptolab-playground/pubmed-arxiv-abstract-embedding-gemma-300m \ + -e embeddinggemma-300m +``` + +Then, you can find the following generated files: + +```bash +. 
+├── centroids +│ └── embeddinggemma-300m +│ ├── centroids.npy +│ └── tree_info.pkl +└── dataset + └── pubmed768d400k + ├── neighbors.parquet + ├── test.npy + └── train.pkl +``` + +### Prepare enVector Server + +To run enVector server with ANN, please refer to the [enVector Deployment repository](https://github.com/CryptoLabInc/envector-deployment). +For example, you can start the server with the following command: + +```bash +# Start enVector server +git clone https://github.com/CryptoLabInc/envector-deployment +cd envector-deployment/docker-compose +./start_envector.sh +``` + +We provide four enVector Docker Images: +- `cryptolabinc/es2e:v1.2.0-alpha.4` +- `cryptolabinc/es2b:v1.2.0-alpha.4` +- `cryptolabinc/es2o:v1.2.0-alpha.4` +- `cryptolabinc/es2c:v1.2.0-alpha.4` + +### Set Environment Variables + +```bash +# Set environment variables +export DATASET_LOCAL_DIR="./dataset" +export NUM_PER_BATCH=4096 +``` + +## Run Benchmark + +Refer to `./scripts/run_benchmark.sh` or `./scripts/envector_benchmark_config.yml` for benchmarks with enVector with ANN (VCT), or use the following command: + +```bash +export NUM_PER_BATCH=500000 # set to the database size for efficiency with IVF_FLAT +python -m vectordb_bench.cli.vectordbbench envectorivfflat \ + --uri "localhost:50050" \ + --eval-mode mm \ + --case-type PerformanceCustomDataset \ + --db-label "PUBMED768D400K-IVF" \ + --custom-case-name PUBMED768D400K \ + --custom-dataset-name PUBMED768D400K \ + --custom-dataset-dir "" \ + --custom-dataset-size 400335 \ + --custom-dataset-dim 768 \ + --custom-dataset-file-count 1 \ + --custom-dataset-with-gt \ + --skip-custom-dataset-use-shuffled \ + --train-centroids True \ + --is-vct True \ + --centroids-path "./centroids/embeddinggemma-300m/centroids.npy" \ + --vct-path "./centroids/embeddinggemma-300m/tree_info.pkl" \ + --nlist 32768 \ + --nprobe 6 +``` \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a922c1fb5..79f70759e 100644 --- a/pyproject.toml +++ 
def prepare_neighbors(
    data_dir: str = "./dataset/pubmed768d400k",
    k: int | None = None,
) -> None:
    """Compute brute-force ground-truth neighbors and save them as Parquet.

    Reads ``train.parquet`` and ``test.parquet`` from ``data_dir``, runs an
    exact inner-product search of every test vector against all train
    vectors, and writes ``neighbors.parquet`` (columns ``id`` and
    ``neighbors_id``) back into ``data_dir``.

    Args:
        data_dir: Directory holding ``train.parquet`` and ``test.parquet``;
            both must provide an ``emb`` column of equal-length vectors.
        k: Number of neighbors stored per query. ``None`` keeps the historical
            behavior of using the number of test queries. The value is always
            capped at the number of train vectors.

    Raises:
        ValueError: If the resulting ``k`` is not positive.
    """
    # load dataset
    train_df = pd.read_parquet(f"{data_dir}/train.parquet")
    test_df = pd.read_parquet(f"{data_dir}/test.parquet")

    train = np.stack(train_df["emb"].to_list()).astype("float32")
    test = np.stack(test_df["emb"].to_list()).astype("float32")
    dim = train.shape[1]

    # flat (exact) inner-product search
    # NOTE(review): inner product only matches cosine ranking when the stored
    # embeddings are L2-normalized -- confirm for the datasets used here.
    index = faiss.IndexFlatIP(dim)
    index.add(train)

    if k is None:
        # Historical default: ground-truth depth equals the number of queries.
        k = len(test)
    # Asking for more neighbors than indexed vectors would only add padding
    # entries (faiss reports missing results as -1), so cap the depth.
    k = min(k, len(train))
    if k <= 0:
        raise ValueError(f"k must be positive, got {k}")

    distances, indices = index.search(test, k)
    print(distances.shape, indices.shape)

    # save flat search result as neighbors
    df = pd.DataFrame({
        "id": np.arange(len(indices)),
        "neighbors_id": indices.tolist(),
    })

    table = pa.Table.from_pandas(df)
    pq.write_table(table, f"{data_dir}/neighbors.parquet")
#!/bin/bash
#
# Run the enVector PUBMED768D400K benchmark cases through the VectorDBBench CLI.
#
# Usage: run_benchmark.sh [--type flat|ivf]
#   Without --type, the FLAT case runs first and the IVF (VCT) case second.
#
# All paths below are relative, so run this from the directory that contains
# ./dataset and ./centroids.

set -euo pipefail

export DATASET_LOCAL_DIR="./dataset"
export NUM_PER_BATCH=4096

CENTROID_PATH=centroids/embeddinggemma-300m/centroids.npy
VCT_PATH=centroids/embeddinggemma-300m/tree_info.pkl
ENVECTOR_URI="localhost:50050"
REQUESTED_TYPE=""

while [[ $# -gt 0 ]]; do
  case "$1" in
    --type)
      # Without this guard, `shift 2` fails under `set -e` and the script
      # exits without telling the user what went wrong.
      if [[ $# -lt 2 ]]; then
        echo "Missing value for --type (expected: flat or ivf)" >&2
        exit 1
      fi
      REQUESTED_TYPE="$2"
      shift 2
      ;;
    --type=*)
      REQUESTED_TYPE="${1#--type=}"
      shift
      ;;
    *)
      echo "Unknown option: $1" >&2
      exit 1
      ;;
  esac
done

case "$REQUESTED_TYPE" in
  ""|flat|ivf) ;;
  *)
    echo "Invalid --type: $REQUESTED_TYPE (expected: flat or ivf)" >&2
    exit 1
    ;;
esac

# Check the IVF inputs up front so a missing file is reported before the
# FLAT case has been run, not after it.
if [[ -z "$REQUESTED_TYPE" || "$REQUESTED_TYPE" == "ivf" ]]; then
  for required in "$CENTROID_PATH" "$VCT_PATH"; do
    if [[ ! -f "$required" ]]; then
      echo "Required file not found: $required (run scripts/prepare_dataset.py first)" >&2
      exit 1
    fi
  done
fi

# Arguments shared by every benchmark case.
COMMON_ARGS=(
  --uri "$ENVECTOR_URI"
  --eval-mode mm
  --case-type PerformanceCustomDataset
  --custom-case-name PUBMED768D400K
  --custom-dataset-name PUBMED768D400K
  --custom-dataset-dir ""
  --custom-dataset-size 400335
  --custom-dataset-dim 768
  --custom-dataset-file-count 1
  --custom-dataset-with-gt
  --skip-custom-dataset-use-shuffled
  --k 10
)

# run_case ENGINE LABEL [EXTRA_ARGS...]
# Invoke one VectorDBBench CLI sub-command with the shared arguments.
run_case() {
  local engine=$1
  local label=$2
  shift 2
  python -m vectordb_bench.cli.vectordbbench "$engine" \
    "${COMMON_ARGS[@]}" \
    --db-label "$label" \
    "$@"
}

if [[ -z "$REQUESTED_TYPE" || "$REQUESTED_TYPE" == "flat" ]]; then
  run_case envectorflat "PUBMED768D400K-FLAT"
fi

if [[ -z "$REQUESTED_TYPE" || "$REQUESTED_TYPE" == "ivf" ]]; then
  export NUM_PER_BATCH=500000 # set database size for efficiency
  run_case envectorivfflat "PUBMED768D400K-IVF" \
    --is-vct True \
    --train-centroids True \
    --centroids-path "$CENTROID_PATH" \
    --vct-path "$VCT_PATH" \
    --nlist 32768 \
    --nprobe 6
fi
@cli.command(name="envectorflat")
@click_parameter_decorators_from_typed_dict(EnVectorFlatIndexTypedDict)
def EnVectorFlat(**parameters: Unpack[EnVectorFlatIndexTypedDict]):
    """Run a benchmark against enVector with a FLAT index.

    ``parameters`` holds the parsed CLI options declared by
    ``EnVectorFlatIndexTypedDict``; they are forwarded unchanged to ``run``.
    """
    from .config import EnVectorConfig, FlatIndexConfig

    run(
        db=DBTYPE,
        db_config=EnVectorConfig(
            db_label=parameters["db_label"],
            uri=SecretStr(parameters["uri"]),
            eval_mode=parameters["eval_mode"],
            index_params={},
        ),
        # The EnVector client takes eval_mode from the case config, so the
        # CLI choice must be passed here; otherwise the "mm" default on
        # FlatIndexConfig silently overrides --eval-mode.
        db_case_config=FlatIndexConfig(eval_mode=parameters["eval_mode"]),
        **parameters,
    )
class EnVectorIndexConfig(BaseModel):
    """Fields and metric helpers shared by every enVector index config."""

    index: IndexType
    metric_type: MetricType | None = None
    use_partition_key: bool = True  # for label-filter

    @property
    def is_gpu_index(self) -> bool:
        """Tell whether ``index`` is one of the GPU index types."""
        gpu_index_types = (
            IndexType.GPU_CAGRA,
            IndexType.GPU_IVF_FLAT,
            IndexType.GPU_IVF_PQ,
            IndexType.GPU_BRUTE_FORCE,
        )
        return self.index in gpu_index_types

    def parse_metric(self) -> str:
        """Return the metric name to use, or ``""`` when no metric is set.

        COSINE is reported as L2 for GPU index types; every other metric is
        returned as its own value.
        """
        metric = self.metric_type
        if not metric:
            return ""

        remap_to_l2 = metric == MetricType.COSINE and self.is_gpu_index
        return MetricType.L2.value if remap_to_l2 else metric.value
class EnVector(VectorDB):
    """VectorDBBench client for the enVector (``es2``) vector database.

    The constructor connects once to create (or reuse) the index and then
    disconnects. Inserts and searches must run inside ``init()``, which opens
    a fresh connection and binds ``self.col`` for the duration of the block.
    """

    supported_filter_types: list[FilterOp] = [
        FilterOp.NonFilter,
        FilterOp.NumGE,
        FilterOp.StrEqual,
    ]

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: EnVectorIndexConfig,
        collection_name: str = "vdbbench",
        drop_old: bool = False,
        name: str = "EnVector",
        with_scalar_labels: bool = False,
        **kwargs,
    ):
        """Create the enVector index if needed and record insert settings.

        Args:
            dim: Dimension of the vectors stored in the index.
            db_config: Connection settings; ``uri``, ``key_path`` and
                ``key_id`` are read from it.
            db_case_config: Index configuration providing ``eval_mode``,
                ``index_param()`` and ``search_param()``.
            collection_name: Name of the enVector index.
            drop_old: Drop an existing index of the same name first.
            name: Label used in log messages.
            with_scalar_labels: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Raises:
            ValueError: IVF_FLAT centroid training was requested without a
                centroids path.
            FileNotFoundError: The configured centroids file does not exist.
        """
        self.name = name
        self.db_config = db_config
        self.case_config = db_case_config
        self.collection_name = collection_name

        self.batch_size = 128 * 32  # default batch size for insertions, can be modified for IVF_FLAT

        self._primary_field = "pk"
        self._scalar_id_field = "id"
        self._scalar_label_field = "label"
        self._vector_field = "vector"
        self._vector_index_name = "vector_idx"
        self._scalar_id_index_name = "id_sort_idx"
        self._scalar_labels_index_name = "labels_idx"
        self.col: es2.Index | None = None

        self.is_vct: bool = False
        self.vct_params: dict[str, Any] = {}

        case_params = self.case_config.index_param()
        index_param = case_params.get("params", {})
        index_type = index_param.get("index_type", "FLAT")

        # set larger batch size for IVF_FLAT insertions; this is done whether
        # the index is created now or already exists, so both paths insert
        # with the same batch size.
        if index_type == "IVF_FLAT":
            self.batch_size = int(os.environ.get("NUM_PER_BATCH", 500_000))
            log.debug(
                f"Set EnVector IVF_FLAT insert batch size to {self.batch_size}. "
                f"This should be the size of dataset for better performance when IVF_FLAT."
            )

        self._connect()
        try:
            if drop_old:
                log.info(f"{self.name} client drop_old index: {self.collection_name}")
                if self.collection_name in es2.get_index_list():
                    es2.drop_index(self.collection_name)

            # Create the collection
            log.info(f"{self.name} create index: {self.collection_name}")

            if self.collection_name in es2.get_index_list():
                log.info(f"{self.name} index {self.collection_name} already exists, skip creating")
                self.is_vct = case_params.get("is_vct", False)
                log.debug(f"IS_VCT: {self.is_vct}")
            else:
                create_kwargs = self._prepare_trained_index(case_params, index_param)

                # create index after training centroids
                es2.create_index(
                    index_name=self.collection_name,
                    dim=dim,
                    key_path=self.db_config.get("key_path"),
                    key_id=self.db_config.get("key_id"),
                    index_params=index_param,
                    eval_mode=self.case_config.eval_mode,
                    **create_kwargs,
                )
        finally:
            # Always release the connection, including when centroid
            # validation or index creation raises.
            es2.disconnect()

    def _connect(self) -> None:
        """Open an es2 connection from the stored DB and case configuration."""
        es2.init(
            address=self.db_config.get("uri"),
            key_path=self.db_config.get("key_path"),
            key_id=self.db_config.get("key_id"),
            eval_mode=self.case_config.eval_mode,
        )

    def _prepare_trained_index(self, case_params: dict, index_param: dict) -> dict[str, Any]:
        """Add centroid/VCT settings to ``index_param`` for a trained IVF_FLAT index.

        Returns the extra keyword arguments for ``es2.create_index``; the
        result is empty unless a VCT tree description is configured.
        """
        create_kwargs: dict[str, Any] = {}
        wants_training = case_params.get("train_centroids", False)
        if index_param.get("index_type", "FLAT") != "IVF_FLAT" or not wants_training:
            return create_kwargs

        centroid_path = case_params.get("centroids_path", None)
        self.is_vct = case_params.get("is_vct", False)
        log.debug(f"IS_VCT: {self.is_vct}")

        if centroid_path is None:
            raise ValueError("Centroids path must be provided for IVF_FLAT index training.")
        if not os.path.exists(centroid_path):
            raise FileNotFoundError(f"Centroid file {centroid_path} not found for IVF_FLAT index training.")

        # load trained centroids from file
        log.debug(f"Centroids: {centroid_path}")
        centroids = np.load(centroid_path)
        log.info(f"{self.name} loaded centroids from {centroid_path} for IVF_FLAT index training.")

        # set centroids for index creation
        index_param["centroids"] = centroids.tolist()

        if self.is_vct:
            # set VCT parameters if applicable
            vct_path = case_params.get("vct_path", None)
            log.debug(f"VCT: {vct_path}")
            index_param["virtual_cluster"] = True
            create_kwargs["tree_description"] = vct_path
            log.info(f"{self.name} VCT parameters set for IVF_FLAT index creation.")

        return create_kwargs

    @contextmanager
    def init(self):
        """Connect to enVector and bind ``self.col`` for the ``with`` block.

        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
            >>>     self.search_embedding()
        """
        self._connect()
        try:
            self.col = es2.Index(self.collection_name)
            if self.is_vct:
                # Look the value up outside the f-string: reusing double
                # quotes inside a replacement field is a SyntaxError before
                # Python 3.12.
                virtual_cluster = self.col.index_config.index_param.index_params["virtual_cluster"]
                log.debug(f"VCT: {virtual_cluster}")
                case_params = self.case_config.index_param()
                is_vct = case_params.get("is_vct", False)
                assert self.is_vct == is_vct, "is_vct mismatch"
                vct_path = case_params.get("vct_path", None)
                self.col._load_virtual_cluster_from_pkl(vct_path)
            yield
        finally:
            self.col = None
            es2.disconnect()

    def create_index(self):
        """No-op: the index is created in ``__init__``."""

    def _optimize(self):
        """No-op optimization hook."""

    def _post_insert(self):
        """No-op post-insert hook."""

    def optimize(self, data_size: int | None = None):
        """Run the (currently empty) optimize step; requires ``init()``."""
        assert self.col, "Please call self.init() before"
        self._optimize()

    def need_normalize_cosine(self) -> bool:
        """Whether this database need to normalize dataset to support COSINE"""
        return True

    def insert_embeddings(
        self,
        embeddings: Iterable[list[float]],
        metadata: list[int],
        labels_data: list[str] | None = None,
        **kwargs,
    ) -> tuple[int, Exception | None]:
        """Insert embeddings into EnVector in batches of ``self.batch_size``.

        Must be called inside ``init()``. Each metadata value is stored as a
        string next to its vector.

        Args:
            embeddings: Vectors to insert; must support ``len()`` and slicing.
            metadata: One id per vector, in the same order.
            labels_data: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``(inserted_count, None)`` on success, or
            ``(inserted_count_so_far, error)`` if a batch failed.
        """
        assert self.col is not None
        assert len(embeddings) == len(metadata)

        log.debug(f"IS_VCT: {self.is_vct}")

        insert_count = 0
        try:
            insert = self.col.insert_vct if self.is_vct else self.col.insert
            for start in range(0, len(embeddings), self.batch_size):
                end = start + self.batch_size  # slicing clamps at the end of the data
                vectors = embeddings[start:end]
                meta = [str(m) for m in metadata[start:end]]
                insert(vectors, meta)
                insert_count += len(vectors)
        except Exception as e:
            log.warning(f"Failed to insert data: {e}")
            return insert_count, e
        return insert_count, None

    def prepare_filter(self, filters: Filter):
        """No-op: filters are not applied by this client."""

    def search_embedding(
        self,
        query: list[float],
        k: int = 10,
        timeout: int | None = None,
    ) -> list[int]:
        """Search for the ``k`` nearest neighbors of ``query``.

        Must be called inside ``init()``. Uses ``search_vct`` when the index
        is a VCT index and ``search`` otherwise.

        Args:
            query: Query vector.
            k: Number of results requested.
            timeout: Accepted for interface compatibility; unused.

        Returns:
            The integer metadata ids of the hits. An empty list is returned
            when the result is empty or malformed, or when the search raises.
        """
        assert self.col is not None

        try:
            search = self.col.search_vct if self.is_vct else self.col.search
            res = search(
                query=query,
                top_k=k,
                output_fields=["metadata"],
                search_params=self.case_config.search_param().get("search_params", {}),
            )

            # Handle empty results
            if not res:
                log.warning(f"Empty search results for query with k={k}")
                return []

            # Extract metadata from results
            # res structure: [[{id: X, score: Y, metadata: Z}, ...]]
            log.debug(f"Search results: {res[0][:1]}")  # Log first 1 results for debugging
            if len(res[0]) > 0:
                return [int(result["metadata"]) for result in res[0] if "metadata" in result]

            log.warning(f"Unexpected result structure: {res}")
            return []

        except Exception as e:
            # NOTE: failures are reported as "no hits" rather than aborting
            # the run, which lowers measured recall; log the traceback so the
            # cause can be diagnosed.
            log.exception(f"Search failed: {e}")
            return []