Skip to content

Commit 7aaa047

Browse files
fix envector lint (#2)
1 parent 9ad0a37 commit 7aaa047

4 files changed

Lines changed: 90 additions & 88 deletions

File tree

vectordb_bench/backend/clients/envector/cli.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -23,15 +23,15 @@ class EnVectorTypedDict(TypedDict):
2323
str,
2424
click.option("--eval-mode", help="Evaluation mode", type=click.Choice(["mm", "rmp"]), default="mm"),
2525
]
26-
26+
2727

2828
class EnVectorFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict): ...
2929

3030

3131
@cli.command(name="envectorflat")
3232
@click_parameter_decorators_from_typed_dict(EnVectorFlatIndexTypedDict)
3333
def EnVectorFlat(**parameters: Unpack[EnVectorFlatIndexTypedDict]):
34-
from .config import FlatIndexConfig, EnVectorConfig
34+
from .config import EnVectorConfig, FlatIndexConfig
3535

3636
run(
3737
db=DBTYPE,
@@ -46,7 +46,7 @@ def EnVectorFlat(**parameters: Unpack[EnVectorFlatIndexTypedDict]):
4646
)
4747

4848

49-
class EnVectorIVFFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict):
49+
class EnVectorIVFFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict):
5050
nlist: Annotated[
5151
int,
5252
click.option("--nlist", type=int, help="nlist for IVF index", default=250),
@@ -76,7 +76,7 @@ class EnVectorIVFFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict):
7676
@cli.command(name="envectorivfflat")
7777
@click_parameter_decorators_from_typed_dict(EnVectorIVFFlatIndexTypedDict)
7878
def EnVectorIVFFlat(**parameters: Unpack[EnVectorIVFFlatIndexTypedDict]):
79-
from .config import IVFFlatIndexConfig, EnVectorConfig
79+
from .config import EnVectorConfig, IVFFlatIndexConfig
8080

8181
run(
8282
db=DBTYPE,
@@ -87,7 +87,7 @@ def EnVectorIVFFlat(**parameters: Unpack[EnVectorIVFFlatIndexTypedDict]):
8787
index_params={"nlist": parameters["nlist"], "nprobe": parameters["nprobe"]},
8888
),
8989
db_case_config=IVFFlatIndexConfig(
90-
nlist=parameters["nlist"],
90+
nlist=parameters["nlist"],
9191
nprobe=parameters["nprobe"],
9292
train_centroids=parameters["train_centroids"],
9393
centroids_path=parameters["centroids_path"],

vectordb_bench/backend/clients/envector/config.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
from pydantic import BaseModel, SecretStr
22

3-
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType, SQType
3+
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
44

55

66
class EnVectorConfig(DBConfig):
@@ -67,10 +67,10 @@ class IVFFlatIndexConfig(EnVectorIndexConfig, DBCaseConfig):
6767
nlist: int = 0
6868
nprobe: int = 0
6969
eval_mode: str = "mm"
70-
train_centroids: bool = False # whether to train centroids before inserting data
70+
train_centroids: bool = False # whether to train centroids before inserting data
7171
centroids_path: str | None = None # path to centroids file
72-
is_vct: bool = False # whether use VCT index
73-
vct_path: str | None = None # path to VCT index file
72+
is_vct: bool = False # whether use VCT index
73+
vct_path: str | None = None # path to VCT index file
7474

7575
def index_param(self) -> dict:
7676
return {

vectordb_bench/backend/clients/envector/envector.py

Lines changed: 79 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -1,23 +1,20 @@
11
"""Wrapper around the EnVector vector database over VectorDB"""
22

3-
from typing import Any, Dict
4-
53
import logging
64
import os
75
from collections.abc import Iterable
86
from contextlib import contextmanager
9-
import pickle
10-
11-
import numpy as np
7+
from pathlib import Path
8+
from typing import Any
129

1310
import es2
11+
import numpy as np
1412

1513
from vectordb_bench.backend.filter import Filter, FilterOp
1614

1715
from ..api import VectorDB
1816
from .config import EnVectorIndexConfig
1917

20-
2118
log = logging.getLogger(__name__)
2219

2320

@@ -45,8 +42,8 @@ def __init__(
4542
self.case_config = db_case_config
4643
self.collection_name = collection_name
4744

48-
self.batch_size = 128 * 32 # default batch size for insertions, can be modified for IVF_FLAT
49-
45+
self.batch_size = 128 * 32 # default batch size for insertions, can be modified for IVF_FLAT
46+
5047
self._primary_field = "pk"
5148
self._scalar_id_field = "id"
5249
self._scalar_label_field = "label"
@@ -57,83 +54,89 @@ def __init__(
5754
self.col: es2.Index | None = None
5855

5956
self.is_vct: bool = False
60-
self.vct_params: Dict[str, Any] = {}
61-
kwargs: Dict[str, Any] = {}
62-
57+
self.vct_params: dict[str, Any] = {}
58+
6359
es2.init(
64-
address=self.db_config.get("uri"),
65-
key_path=self.db_config.get("key_path"),
60+
address=self.db_config.get("uri"),
61+
key_path=self.db_config.get("key_path"),
6662
key_id=self.db_config.get("key_id"),
6763
eval_mode=self.case_config.eval_mode,
6864
)
6965
if drop_old:
70-
log.info(f"{self.name} client drop_old index: {self.collection_name}")
71-
if self.collection_name in es2.get_index_list():
66+
log.info(f"{self.name} client drop_old index: {self.collection_name}")
67+
if self.collection_name in es2.get_index_list():
7268
es2.drop_index(self.collection_name)
73-
69+
7470
# Create the collection
7571
log.info(f"{self.name} create index: {self.collection_name}")
76-
72+
73+
index_kwargs = dict(kwargs)
74+
self._ensure_index(dim, index_kwargs)
75+
76+
es2.disconnect()
77+
78+
def _ensure_index(self, dim: int, index_kwargs: dict[str, Any]):
7779
if self.collection_name in es2.get_index_list():
7880
log.info(f"{self.name} index {self.collection_name} already exists, skip creating")
7981
self.is_vct = self.case_config.index_param().get("is_vct", False)
8082
log.debug(f"IS_VCT: {self.is_vct}")
83+
return
84+
self._create_index(dim, index_kwargs)
8185

82-
else:
83-
index_param = self.case_config.index_param().get("params", {})
84-
index_type = index_param.get("index_type", "FLAT")
85-
train_centroids = self.case_config.index_param().get("train_centroids", False)
86-
87-
if index_type == "IVF_FLAT" and train_centroids:
88-
89-
centroid_path = self.case_config.index_param().get("centroids_path", None)
90-
self.is_vct = self.case_config.index_param().get("is_vct", False)
91-
log.debug(f"IS_VCT: {self.is_vct}")
92-
93-
if centroid_path is not None:
94-
if not os.path.exists(centroid_path):
95-
raise FileNotFoundError(f"Centroid file {centroid_path} not found for IVF_FLAT index training.")
96-
97-
# load trained centroids from file
98-
log.debug(f"Centroids: {centroid_path}")
99-
centroids = np.load(centroid_path)
100-
log.info(f"{self.name} loaded centroids from {centroid_path} for IVF_FLAT index training.")
101-
102-
# set centroids for index creation
103-
index_param["centroids"] = centroids.tolist()
104-
105-
if self.is_vct:
106-
# set VCT parameters if applicable
107-
vct_path = self.case_config.index_param().get("vct_path", None)
108-
log.debug(f"VCT: {vct_path}")
109-
index_param["virtual_cluster"] = True
110-
kwargs["tree_description"] = vct_path
111-
self.is_vct = True
112-
log.info(f"{self.name} VCT parameters set for IVF_FLAT index creation.")
86+
def _create_index(self, dim: int, index_kwargs: dict[str, Any]):
87+
index_param = self.case_config.index_param().get("params", {})
88+
index_type = index_param.get("index_type", "FLAT")
89+
train_centroids = self.case_config.index_param().get("train_centroids", False)
11390

114-
else:
115-
raise ValueError("Centroids path must be provided for IVF_FLAT index training.")
116-
117-
# set larger batch size for IVF_FLAT insertions
118-
if index_type == "IVF_FLAT":
119-
self.batch_size = int(os.environ.get("NUM_PER_BATCH", 500_000))
120-
log.debug(
121-
f"Set EnVector IVF_FLAT insert batch size to {self.batch_size}. "
122-
f"This should be the size of dataset for better performance when IVF_FLAT."
123-
)
91+
if index_type == "IVF_FLAT" and train_centroids:
92+
self._configure_centroids(index_param, index_kwargs)
12493

125-
# create index after training centroids
126-
es2.create_index(
127-
index_name=self.collection_name,
128-
dim=dim,
129-
key_path=self.db_config.get("key_path"),
130-
key_id=self.db_config.get("key_id"),
131-
index_params=index_param,
132-
eval_mode=self.case_config.eval_mode,
133-
**kwargs,
134-
)
94+
if index_type == "IVF_FLAT":
95+
self._adjust_batch_size()
13596

136-
es2.disconnect()
97+
es2.create_index(
98+
index_name=self.collection_name,
99+
dim=dim,
100+
key_path=self.db_config.get("key_path"),
101+
key_id=self.db_config.get("key_id"),
102+
index_params=index_param,
103+
eval_mode=self.case_config.eval_mode,
104+
**index_kwargs,
105+
)
106+
107+
def _configure_centroids(self, index_param: dict[str, Any], index_kwargs: dict[str, Any]):
108+
centroid_path = self.case_config.index_param().get("centroids_path", None)
109+
self.is_vct = self.case_config.index_param().get("is_vct", False)
110+
log.debug(f"IS_VCT: {self.is_vct}")
111+
112+
if centroid_path is None:
113+
raise ValueError("Centroids path must be provided for IVF_FLAT index training.")
114+
115+
centroid_file = Path(centroid_path)
116+
if not centroid_file.exists():
117+
msg = f"Centroid file {centroid_path} not found for IVF_FLAT index training."
118+
raise FileNotFoundError(msg)
119+
120+
log.debug(f"Centroids: {centroid_path}")
121+
centroids = np.load(centroid_file)
122+
log.info(f"{self.name} loaded centroids from {centroid_path} for IVF_FLAT index training.")
123+
124+
index_param["centroids"] = centroids.tolist()
125+
126+
if self.is_vct:
127+
vct_path = self.case_config.index_param().get("vct_path", None)
128+
log.debug(f"VCT: {vct_path}")
129+
index_param["virtual_cluster"] = True
130+
index_kwargs["tree_description"] = vct_path
131+
self.is_vct = True
132+
log.info(f"{self.name} VCT parameters set for IVF_FLAT index creation.")
133+
134+
def _adjust_batch_size(self):
135+
self.batch_size = int(os.environ.get("NUM_PER_BATCH", "500000"))
136+
log.debug(
137+
f"Set EnVector IVF_FLAT insert batch size to {self.batch_size}. "
138+
f"This should be the size of dataset for better performance when IVF_FLAT."
139+
)
137140

138141
@contextmanager
139142
def init(self):
@@ -152,7 +155,7 @@ def init(self):
152155
try:
153156
self.col = es2.Index(self.collection_name)
154157
if self.is_vct:
155-
log.debug(f"VCT: {self.col.index_config.index_param.index_params["virtual_cluster"]}")
158+
log.debug(f"VCT: {self.col.index_config.index_param.index_params['virtual_cluster']}")
156159
is_vct = self.case_config.index_param().get("is_vct", False)
157160
assert self.is_vct == is_vct, "is_vct mismatch"
158161
vct_path = self.case_config.index_param().get("vct_path", None)
@@ -190,7 +193,7 @@ def insert_embeddings(
190193
# use the first insert_embeddings to init collection
191194
assert self.col is not None
192195
assert len(embeddings) == len(metadata)
193-
196+
194197
log.debug(f"IS_VCT: {self.is_vct}")
195198

196199
insert_count = 0
@@ -229,7 +232,7 @@ def search_embedding(
229232
output_fields=["metadata"],
230233
search_params=self.case_config.search_param().get("search_params", {}),
231234
)
232-
235+
233236
else:
234237
# Perform the search.
235238
res = self.col.search(
@@ -247,12 +250,11 @@ def search_embedding(
247250
# Extract metadata from results
248251
# res structure: [[{id: X, score: Y, metadata: Z}, ...]]
249252
log.debug(f"Search results: {res[0][:1]}") # Log first 1 results for debugging
250-
if len(res) > 0 and len(res[0]) > 0:
251-
return [int(result["metadata"]) for result in res[0] if "metadata" in result]
252-
else:
253+
if not (res and len(res[0]) > 0):
253254
log.warning(f"Unexpected result structure: {res}")
254255
return []
256+
return [int(result["metadata"]) for result in res[0] if "metadata" in result]
255257

256-
except Exception as e:
257-
log.error(f"Search failed: {e}")
258+
except Exception:
259+
log.exception("Search failed")
258260
return []

vectordb_bench/log_util.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,12 @@
11
import logging
2+
import os
23
from logging import config
34
from pathlib import Path
4-
import os
55

66

77
def init(log_level: str):
88
os.environ["TQDM_DISABLE"] = "1"
9-
9+
1010
# Create logs directory if it doesn't exist
1111
log_dir = Path("logs")
1212
log_dir.mkdir(exist_ok=True)

0 commit comments

Comments
 (0)