Skip to content

Commit 28061e9

Browse files
committed
fix
1 parent 9ad0a37 commit 28061e9

4 files changed

Lines changed: 33 additions & 38 deletions

File tree

vectordb_bench/backend/clients/envector/cli.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ class EnVectorTypedDict(TypedDict):
2323
str,
2424
click.option("--eval-mode", help="Evaluation mode", type=click.Choice(["mm", "rmp"]), default="mm"),
2525
]
26-
26+
2727

2828
class EnVectorFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict): ...
2929

3030

3131
@cli.command(name="envectorflat")
3232
@click_parameter_decorators_from_typed_dict(EnVectorFlatIndexTypedDict)
3333
def EnVectorFlat(**parameters: Unpack[EnVectorFlatIndexTypedDict]):
34-
from .config import FlatIndexConfig, EnVectorConfig
34+
from .config import EnVectorConfig, FlatIndexConfig
3535

3636
run(
3737
db=DBTYPE,
@@ -46,7 +46,7 @@ def EnVectorFlat(**parameters: Unpack[EnVectorFlatIndexTypedDict]):
4646
)
4747

4848

49-
class EnVectorIVFFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict):
49+
class EnVectorIVFFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict):
5050
nlist: Annotated[
5151
int,
5252
click.option("--nlist", type=int, help="nlist for IVF index", default=250),
@@ -76,7 +76,7 @@ class EnVectorIVFFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict):
7676
@cli.command(name="envectorivfflat")
7777
@click_parameter_decorators_from_typed_dict(EnVectorIVFFlatIndexTypedDict)
7878
def EnVectorIVFFlat(**parameters: Unpack[EnVectorIVFFlatIndexTypedDict]):
79-
from .config import IVFFlatIndexConfig, EnVectorConfig
79+
from .config import EnVectorConfig, IVFFlatIndexConfig
8080

8181
run(
8282
db=DBTYPE,
@@ -87,7 +87,7 @@ def EnVectorIVFFlat(**parameters: Unpack[EnVectorIVFFlatIndexTypedDict]):
8787
index_params={"nlist": parameters["nlist"], "nprobe": parameters["nprobe"]},
8888
),
8989
db_case_config=IVFFlatIndexConfig(
90-
nlist=parameters["nlist"],
90+
nlist=parameters["nlist"],
9191
nprobe=parameters["nprobe"],
9292
train_centroids=parameters["train_centroids"],
9393
centroids_path=parameters["centroids_path"],

vectordb_bench/backend/clients/envector/config.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pydantic import BaseModel, SecretStr
22

3-
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType, SQType
3+
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
44

55

66
class EnVectorConfig(DBConfig):
@@ -67,10 +67,10 @@ class IVFFlatIndexConfig(EnVectorIndexConfig, DBCaseConfig):
6767
nlist: int = 0
6868
nprobe: int = 0
6969
eval_mode: str = "mm"
70-
train_centroids: bool = False # whether to train centroids before inserting data
70+
train_centroids: bool = False # whether to train centroids before inserting data
7171
centroids_path: str | None = None # path to centroids file
72-
is_vct: bool = False # whether use VCT index
73-
vct_path: str | None = None # path to VCT index file
72+
is_vct: bool = False # whether use VCT index
73+
vct_path: str | None = None # path to VCT index file
7474

7575
def index_param(self) -> dict:
7676
return {

vectordb_bench/backend/clients/envector/envector.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,19 @@
11
"""Wrapper around the EnVector vector database over VectorDB"""
22

3-
from typing import Any, Dict
4-
53
import logging
64
import os
75
from collections.abc import Iterable
86
from contextlib import contextmanager
9-
import pickle
10-
11-
import numpy as np
7+
from typing import Any
128

139
import es2
10+
import numpy as np
1411

1512
from vectordb_bench.backend.filter import Filter, FilterOp
1613

1714
from ..api import VectorDB
1815
from .config import EnVectorIndexConfig
1916

20-
2117
log = logging.getLogger(__name__)
2218

2319

@@ -45,8 +41,8 @@ def __init__(
4541
self.case_config = db_case_config
4642
self.collection_name = collection_name
4743

48-
self.batch_size = 128 * 32 # default batch size for insertions, can be modified for IVF_FLAT
49-
44+
self.batch_size = 128 * 32 # default batch size for insertions, can be modified for IVF_FLAT
45+
5046
self._primary_field = "pk"
5147
self._scalar_id_field = "id"
5248
self._scalar_label_field = "label"
@@ -57,23 +53,23 @@ def __init__(
5753
self.col: es2.Index | None = None
5854

5955
self.is_vct: bool = False
60-
self.vct_params: Dict[str, Any] = {}
61-
kwargs: Dict[str, Any] = {}
62-
56+
self.vct_params: dict[str, Any] = {}
57+
kwargs: dict[str, Any] = {}
58+
6359
es2.init(
64-
address=self.db_config.get("uri"),
65-
key_path=self.db_config.get("key_path"),
60+
address=self.db_config.get("uri"),
61+
key_path=self.db_config.get("key_path"),
6662
key_id=self.db_config.get("key_id"),
6763
eval_mode=self.case_config.eval_mode,
6864
)
6965
if drop_old:
70-
log.info(f"{self.name} client drop_old index: {self.collection_name}")
71-
if self.collection_name in es2.get_index_list():
66+
log.info(f"{self.name} client drop_old index: {self.collection_name}")
67+
if self.collection_name in es2.get_index_list():
7268
es2.drop_index(self.collection_name)
73-
69+
7470
# Create the collection
7571
log.info(f"{self.name} create index: {self.collection_name}")
76-
72+
7773
if self.collection_name in es2.get_index_list():
7874
log.info(f"{self.name} index {self.collection_name} already exists, skip creating")
7975
self.is_vct = self.case_config.index_param().get("is_vct", False)
@@ -83,21 +79,21 @@ def __init__(
8379
index_param = self.case_config.index_param().get("params", {})
8480
index_type = index_param.get("index_type", "FLAT")
8581
train_centroids = self.case_config.index_param().get("train_centroids", False)
86-
82+
8783
if index_type == "IVF_FLAT" and train_centroids:
88-
84+
8985
centroid_path = self.case_config.index_param().get("centroids_path", None)
9086
self.is_vct = self.case_config.index_param().get("is_vct", False)
9187
log.debug(f"IS_VCT: {self.is_vct}")
92-
88+
9389
if centroid_path is not None:
9490
if not os.path.exists(centroid_path):
9591
raise FileNotFoundError(f"Centroid file {centroid_path} not found for IVF_FLAT index training.")
96-
92+
9793
# load trained centroids from file
9894
log.debug(f"Centroids: {centroid_path}")
9995
centroids = np.load(centroid_path)
100-
log.info(f"{self.name} loaded centroids from {centroid_path} for IVF_FLAT index training.")
96+
log.info(f"{self.name} loaded centroids from {centroid_path} for IVF_FLAT index training.")
10197

10298
# set centroids for index creation
10399
index_param["centroids"] = centroids.tolist()
@@ -190,7 +186,7 @@ def insert_embeddings(
190186
# use the first insert_embeddings to init collection
191187
assert self.col is not None
192188
assert len(embeddings) == len(metadata)
193-
189+
194190
log.debug(f"IS_VCT: {self.is_vct}")
195191

196192
insert_count = 0
@@ -229,7 +225,7 @@ def search_embedding(
229225
output_fields=["metadata"],
230226
search_params=self.case_config.search_param().get("search_params", {}),
231227
)
232-
228+
233229
else:
234230
# Perform the search.
235231
res = self.col.search(
@@ -249,9 +245,8 @@ def search_embedding(
249245
log.debug(f"Search results: {res[0][:1]}") # Log first 1 results for debugging
250246
if len(res) > 0 and len(res[0]) > 0:
251247
return [int(result["metadata"]) for result in res[0] if "metadata" in result]
252-
else:
253-
log.warning(f"Unexpected result structure: {res}")
254-
return []
248+
log.warning(f"Unexpected result structure: {res}")
249+
return []
255250

256251
except Exception as e:
257252
log.error(f"Search failed: {e}")

vectordb_bench/log_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import logging
2+
import os
23
from logging import config
34
from pathlib import Path
4-
import os
55

66

77
def init(log_level: str):
88
os.environ["TQDM_DISABLE"] = "1"
9-
9+
1010
# Create logs directory if it doesn't exist
1111
log_dir = Path("logs")
1212
log_dir.mkdir(exist_ok=True)

0 commit comments

Comments
 (0)