11"""Wrapper around the EnVector vector database over VectorDB"""
22
3- from typing import Any , Dict
4-
53import logging
64import os
75from collections .abc import Iterable
86from contextlib import contextmanager
9- import pickle
10-
11- import numpy as np
7+ from typing import Any
128
139import es2
10+ import numpy as np
1411
1512from vectordb_bench .backend .filter import Filter , FilterOp
1613
1714from ..api import VectorDB
1815from .config import EnVectorIndexConfig
1916
20-
2117log = logging .getLogger (__name__ )
2218
2319
@@ -45,8 +41,8 @@ def __init__(
4541 self .case_config = db_case_config
4642 self .collection_name = collection_name
4743
48- self .batch_size = 128 * 32 # default batch size for insertions, can be modified for IVF_FLAT
49-
44+ self .batch_size = 128 * 32 # default batch size for insertions, can be modified for IVF_FLAT
45+
5046 self ._primary_field = "pk"
5147 self ._scalar_id_field = "id"
5248 self ._scalar_label_field = "label"
@@ -57,23 +53,23 @@ def __init__(
5753 self .col : es2 .Index | None = None
5854
5955 self .is_vct : bool = False
60- self .vct_params : Dict [str , Any ] = {}
61- kwargs : Dict [str , Any ] = {}
62-
56+ self .vct_params : dict [str , Any ] = {}
57+ kwargs : dict [str , Any ] = {}
58+
6359 es2 .init (
64- address = self .db_config .get ("uri" ),
65- key_path = self .db_config .get ("key_path" ),
60+ address = self .db_config .get ("uri" ),
61+ key_path = self .db_config .get ("key_path" ),
6662 key_id = self .db_config .get ("key_id" ),
6763 eval_mode = self .case_config .eval_mode ,
6864 )
6965 if drop_old :
70- log .info (f"{ self .name } client drop_old index: { self .collection_name } " )
71- if self .collection_name in es2 .get_index_list ():
66+ log .info (f"{ self .name } client drop_old index: { self .collection_name } " )
67+ if self .collection_name in es2 .get_index_list ():
7268 es2 .drop_index (self .collection_name )
73-
69+
7470 # Create the collection
7571 log .info (f"{ self .name } create index: { self .collection_name } " )
76-
72+
7773 if self .collection_name in es2 .get_index_list ():
7874 log .info (f"{ self .name } index { self .collection_name } already exists, skip creating" )
7975 self .is_vct = self .case_config .index_param ().get ("is_vct" , False )
@@ -83,21 +79,21 @@ def __init__(
8379 index_param = self .case_config .index_param ().get ("params" , {})
8480 index_type = index_param .get ("index_type" , "FLAT" )
8581 train_centroids = self .case_config .index_param ().get ("train_centroids" , False )
86-
82+
8783 if index_type == "IVF_FLAT" and train_centroids :
88-
84+
8985 centroid_path = self .case_config .index_param ().get ("centroids_path" , None )
9086 self .is_vct = self .case_config .index_param ().get ("is_vct" , False )
9187 log .debug (f"IS_VCT: { self .is_vct } " )
92-
88+
9389 if centroid_path is not None :
9490 if not os .path .exists (centroid_path ):
9591 raise FileNotFoundError (f"Centroid file { centroid_path } not found for IVF_FLAT index training." )
96-
92+
9793 # load trained centroids from file
9894 log .debug (f"Centroids: { centroid_path } " )
9995 centroids = np .load (centroid_path )
100- log .info (f"{ self .name } loaded centroids from { centroid_path } for IVF_FLAT index training." )
96+ log .info (f"{ self .name } loaded centroids from { centroid_path } for IVF_FLAT index training." )
10197
10298 # set centroids for index creation
10399 index_param ["centroids" ] = centroids .tolist ()
@@ -190,7 +186,7 @@ def insert_embeddings(
190186 # use the first insert_embeddings to init collection
191187 assert self .col is not None
192188 assert len (embeddings ) == len (metadata )
193-
189+
194190 log .debug (f"IS_VCT: { self .is_vct } " )
195191
196192 insert_count = 0
@@ -229,7 +225,7 @@ def search_embedding(
229225 output_fields = ["metadata" ],
230226 search_params = self .case_config .search_param ().get ("search_params" , {}),
231227 )
232-
228+
233229 else :
234230 # Perform the search.
235231 res = self .col .search (
@@ -249,9 +245,8 @@ def search_embedding(
249245 log .debug (f"Search results: { res [0 ][:1 ]} " ) # Log first 1 results for debugging
250246 if len (res ) > 0 and len (res [0 ]) > 0 :
251247 return [int (result ["metadata" ]) for result in res [0 ] if "metadata" in result ]
252- else :
253- log .warning (f"Unexpected result structure: { res } " )
254- return []
248+ log .warning (f"Unexpected result structure: { res } " )
249+ return []
255250
256251 except Exception as e :
257252 log .error (f"Search failed: { e } " )
0 commit comments