Skip to content

Commit 82d1491

Browse files
committed
fix lint
1 parent 28061e9 commit 82d1491

1 file changed

Lines changed: 65 additions & 58 deletions

File tree

vectordb_bench/backend/clients/envector/envector.py

Lines changed: 65 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
from collections.abc import Iterable
66
from contextlib import contextmanager
7+
from pathlib import Path
78
from typing import Any
89

910
import es2
@@ -54,7 +55,6 @@ def __init__(
5455

5556
self.is_vct: bool = False
5657
self.vct_params: dict[str, Any] = {}
57-
kwargs: dict[str, Any] = {}
5858

5959
es2.init(
6060
address=self.db_config.get("uri"),
@@ -70,66 +70,73 @@ def __init__(
7070
# Create the collection
7171
log.info(f"{self.name} create index: {self.collection_name}")
7272

73+
index_kwargs = dict(kwargs)
74+
self._ensure_index(dim, index_kwargs)
75+
76+
es2.disconnect()
77+
78+
def _ensure_index(self, dim: int, index_kwargs: dict[str, Any]):
7379
if self.collection_name in es2.get_index_list():
7480
log.info(f"{self.name} index {self.collection_name} already exists, skip creating")
7581
self.is_vct = self.case_config.index_param().get("is_vct", False)
7682
log.debug(f"IS_VCT: {self.is_vct}")
83+
return
84+
self._create_index(dim, index_kwargs)
7785

78-
else:
79-
index_param = self.case_config.index_param().get("params", {})
80-
index_type = index_param.get("index_type", "FLAT")
81-
train_centroids = self.case_config.index_param().get("train_centroids", False)
82-
83-
if index_type == "IVF_FLAT" and train_centroids:
84-
85-
centroid_path = self.case_config.index_param().get("centroids_path", None)
86-
self.is_vct = self.case_config.index_param().get("is_vct", False)
87-
log.debug(f"IS_VCT: {self.is_vct}")
86+
def _create_index(self, dim: int, index_kwargs: dict[str, Any]):
87+
index_param = self.case_config.index_param().get("params", {})
88+
index_type = index_param.get("index_type", "FLAT")
89+
train_centroids = self.case_config.index_param().get("train_centroids", False)
8890

89-
if centroid_path is not None:
90-
if not os.path.exists(centroid_path):
91-
raise FileNotFoundError(f"Centroid file {centroid_path} not found for IVF_FLAT index training.")
91+
if index_type == "IVF_FLAT" and train_centroids:
92+
self._configure_centroids(index_param, index_kwargs)
9293

93-
# load trained centroids from file
94-
log.debug(f"Centroids: {centroid_path}")
95-
centroids = np.load(centroid_path)
96-
log.info(f"{self.name} loaded centroids from {centroid_path} for IVF_FLAT index training.")
94+
if index_type == "IVF_FLAT":
95+
self._adjust_batch_size()
9796

98-
# set centroids for index creation
99-
index_param["centroids"] = centroids.tolist()
100-
101-
if self.is_vct:
102-
# set VCT parameters if applicable
103-
vct_path = self.case_config.index_param().get("vct_path", None)
104-
log.debug(f"VCT: {vct_path}")
105-
index_param["virtual_cluster"] = True
106-
kwargs["tree_description"] = vct_path
107-
self.is_vct = True
108-
log.info(f"{self.name} VCT parameters set for IVF_FLAT index creation.")
109-
110-
else:
111-
raise ValueError("Centroids path must be provided for IVF_FLAT index training.")
112-
113-
# set larger batch size for IVF_FLAT insertions
114-
if index_type == "IVF_FLAT":
115-
self.batch_size = int(os.environ.get("NUM_PER_BATCH", 500_000))
116-
log.debug(
117-
f"Set EnVector IVF_FLAT insert batch size to {self.batch_size}. "
118-
f"This should be the size of dataset for better performance when IVF_FLAT."
119-
)
97+
es2.create_index(
98+
index_name=self.collection_name,
99+
dim=dim,
100+
key_path=self.db_config.get("key_path"),
101+
key_id=self.db_config.get("key_id"),
102+
index_params=index_param,
103+
eval_mode=self.case_config.eval_mode,
104+
**index_kwargs,
105+
)
120106

121-
# create index after training centroids
122-
es2.create_index(
123-
index_name=self.collection_name,
124-
dim=dim,
125-
key_path=self.db_config.get("key_path"),
126-
key_id=self.db_config.get("key_id"),
127-
index_params=index_param,
128-
eval_mode=self.case_config.eval_mode,
129-
**kwargs,
130-
)
107+
def _configure_centroids(self, index_param: dict[str, Any], index_kwargs: dict[str, Any]):
108+
centroid_path = self.case_config.index_param().get("centroids_path", None)
109+
self.is_vct = self.case_config.index_param().get("is_vct", False)
110+
log.debug(f"IS_VCT: {self.is_vct}")
131111

132-
es2.disconnect()
112+
if centroid_path is None:
113+
raise ValueError("Centroids path must be provided for IVF_FLAT index training.")
114+
115+
centroid_file = Path(centroid_path)
116+
if not centroid_file.exists():
117+
msg = f"Centroid file {centroid_path} not found for IVF_FLAT index training."
118+
raise FileNotFoundError(msg)
119+
120+
log.debug(f"Centroids: {centroid_path}")
121+
centroids = np.load(centroid_file)
122+
log.info(f"{self.name} loaded centroids from {centroid_path} for IVF_FLAT index training.")
123+
124+
index_param["centroids"] = centroids.tolist()
125+
126+
if self.is_vct:
127+
vct_path = self.case_config.index_param().get("vct_path", None)
128+
log.debug(f"VCT: {vct_path}")
129+
index_param["virtual_cluster"] = True
130+
index_kwargs["tree_description"] = vct_path
131+
self.is_vct = True
132+
log.info(f"{self.name} VCT parameters set for IVF_FLAT index creation.")
133+
134+
def _adjust_batch_size(self):
135+
self.batch_size = int(os.environ.get("NUM_PER_BATCH", "500000"))
136+
log.debug(
137+
f"Set EnVector IVF_FLAT insert batch size to {self.batch_size}. "
138+
f"This should be the size of dataset for better performance when IVF_FLAT."
139+
)
133140

134141
@contextmanager
135142
def init(self):
@@ -148,7 +155,7 @@ def init(self):
148155
try:
149156
self.col = es2.Index(self.collection_name)
150157
if self.is_vct:
151-
log.debug(f"VCT: {self.col.index_config.index_param.index_params["virtual_cluster"]}")
158+
log.debug(f"VCT: {self.col.index_config.index_param.index_params['virtual_cluster']}")
152159
is_vct = self.case_config.index_param().get("is_vct", False)
153160
assert self.is_vct == is_vct, "is_vct mismatch"
154161
vct_path = self.case_config.index_param().get("vct_path", None)
@@ -243,11 +250,11 @@ def search_embedding(
243250
# Extract metadata from results
244251
# res structure: [[{id: X, score: Y, metadata: Z}, ...]]
245252
log.debug(f"Search results: {res[0][:1]}") # Log first 1 results for debugging
246-
if len(res) > 0 and len(res[0]) > 0:
247-
return [int(result["metadata"]) for result in res[0] if "metadata" in result]
248-
log.warning(f"Unexpected result structure: {res}")
249-
return []
253+
if not (res and len(res[0]) > 0):
254+
log.warning(f"Unexpected result structure: {res}")
255+
return []
256+
return [int(result["metadata"]) for result in res[0] if "metadata" in result]
250257

251-
except Exception as e:
252-
log.error(f"Search failed: {e}")
258+
except Exception:
259+
log.exception("Search failed")
253260
return []

0 commit comments

Comments
 (0)