11"""Wrapper around the EnVector vector database over VectorDB"""
22
3- from typing import Any , Dict
4-
53import logging
64import os
75from collections .abc import Iterable
86from contextlib import contextmanager
9- import pickle
10-
11- import numpy as np
7+ from pathlib import Path
8+ from typing import Any
129
1310import es2
11+ import numpy as np
1412
1513from vectordb_bench .backend .filter import Filter , FilterOp
1614
1715from ..api import VectorDB
1816from .config import EnVectorIndexConfig
1917
20-
2118log = logging .getLogger (__name__ )
2219
2320
@@ -45,8 +42,8 @@ def __init__(
4542 self .case_config = db_case_config
4643 self .collection_name = collection_name
4744
48- self .batch_size = 128 * 32 # default batch size for insertions, can be modified for IVF_FLAT
49-
45+ self .batch_size = 128 * 32 # default batch size for insertions, can be modified for IVF_FLAT
46+
5047 self ._primary_field = "pk"
5148 self ._scalar_id_field = "id"
5249 self ._scalar_label_field = "label"
@@ -57,83 +54,89 @@ def __init__(
5754 self .col : es2 .Index | None = None
5855
5956 self .is_vct : bool = False
60- self .vct_params : Dict [str , Any ] = {}
61- kwargs : Dict [str , Any ] = {}
62-
57+ self .vct_params : dict [str , Any ] = {}
58+
6359 es2 .init (
64- address = self .db_config .get ("uri" ),
65- key_path = self .db_config .get ("key_path" ),
60+ address = self .db_config .get ("uri" ),
61+ key_path = self .db_config .get ("key_path" ),
6662 key_id = self .db_config .get ("key_id" ),
6763 eval_mode = self .case_config .eval_mode ,
6864 )
6965 if drop_old :
70- log .info (f"{ self .name } client drop_old index: { self .collection_name } " )
71- if self .collection_name in es2 .get_index_list ():
66+ log .info (f"{ self .name } client drop_old index: { self .collection_name } " )
67+ if self .collection_name in es2 .get_index_list ():
7268 es2 .drop_index (self .collection_name )
73-
69+
7470 # Create the collection
7571 log .info (f"{ self .name } create index: { self .collection_name } " )
76-
72+
73+ index_kwargs = dict (kwargs )
74+ self ._ensure_index (dim , index_kwargs )
75+
76+ es2 .disconnect ()
77+
78+ def _ensure_index (self , dim : int , index_kwargs : dict [str , Any ]):
7779 if self .collection_name in es2 .get_index_list ():
7880 log .info (f"{ self .name } index { self .collection_name } already exists, skip creating" )
7981 self .is_vct = self .case_config .index_param ().get ("is_vct" , False )
8082 log .debug (f"IS_VCT: { self .is_vct } " )
83+ return
84+ self ._create_index (dim , index_kwargs )
8185
82- else :
83- index_param = self .case_config .index_param ().get ("params" , {})
84- index_type = index_param .get ("index_type" , "FLAT" )
85- train_centroids = self .case_config .index_param ().get ("train_centroids" , False )
86-
87- if index_type == "IVF_FLAT" and train_centroids :
88-
89- centroid_path = self .case_config .index_param ().get ("centroids_path" , None )
90- self .is_vct = self .case_config .index_param ().get ("is_vct" , False )
91- log .debug (f"IS_VCT: { self .is_vct } " )
92-
93- if centroid_path is not None :
94- if not os .path .exists (centroid_path ):
95- raise FileNotFoundError (f"Centroid file { centroid_path } not found for IVF_FLAT index training." )
96-
97- # load trained centroids from file
98- log .debug (f"Centroids: { centroid_path } " )
99- centroids = np .load (centroid_path )
100- log .info (f"{ self .name } loaded centroids from { centroid_path } for IVF_FLAT index training." )
101-
102- # set centroids for index creation
103- index_param ["centroids" ] = centroids .tolist ()
104-
105- if self .is_vct :
106- # set VCT parameters if applicable
107- vct_path = self .case_config .index_param ().get ("vct_path" , None )
108- log .debug (f"VCT: { vct_path } " )
109- index_param ["virtual_cluster" ] = True
110- kwargs ["tree_description" ] = vct_path
111- self .is_vct = True
112- log .info (f"{ self .name } VCT parameters set for IVF_FLAT index creation." )
86+ def _create_index (self , dim : int , index_kwargs : dict [str , Any ]):
87+ index_param = self .case_config .index_param ().get ("params" , {})
88+ index_type = index_param .get ("index_type" , "FLAT" )
89+ train_centroids = self .case_config .index_param ().get ("train_centroids" , False )
11390
114- else :
115- raise ValueError ("Centroids path must be provided for IVF_FLAT index training." )
116-
117- # set larger batch size for IVF_FLAT insertions
118- if index_type == "IVF_FLAT" :
119- self .batch_size = int (os .environ .get ("NUM_PER_BATCH" , 500_000 ))
120- log .debug (
121- f"Set EnVector IVF_FLAT insert batch size to { self .batch_size } . "
122- f"This should be the size of dataset for better performance when IVF_FLAT."
123- )
91+ if index_type == "IVF_FLAT" and train_centroids :
92+ self ._configure_centroids (index_param , index_kwargs )
12493
125- # create index after training centroids
126- es2 .create_index (
127- index_name = self .collection_name ,
128- dim = dim ,
129- key_path = self .db_config .get ("key_path" ),
130- key_id = self .db_config .get ("key_id" ),
131- index_params = index_param ,
132- eval_mode = self .case_config .eval_mode ,
133- ** kwargs ,
134- )
94+ if index_type == "IVF_FLAT" :
95+ self ._adjust_batch_size ()
13596
136- es2 .disconnect ()
97+ es2 .create_index (
98+ index_name = self .collection_name ,
99+ dim = dim ,
100+ key_path = self .db_config .get ("key_path" ),
101+ key_id = self .db_config .get ("key_id" ),
102+ index_params = index_param ,
103+ eval_mode = self .case_config .eval_mode ,
104+ ** index_kwargs ,
105+ )
106+
107+ def _configure_centroids (self , index_param : dict [str , Any ], index_kwargs : dict [str , Any ]):
108+ centroid_path = self .case_config .index_param ().get ("centroids_path" , None )
109+ self .is_vct = self .case_config .index_param ().get ("is_vct" , False )
110+ log .debug (f"IS_VCT: { self .is_vct } " )
111+
112+ if centroid_path is None :
113+ raise ValueError ("Centroids path must be provided for IVF_FLAT index training." )
114+
115+ centroid_file = Path (centroid_path )
116+ if not centroid_file .exists ():
117+ msg = f"Centroid file { centroid_path } not found for IVF_FLAT index training."
118+ raise FileNotFoundError (msg )
119+
120+ log .debug (f"Centroids: { centroid_path } " )
121+ centroids = np .load (centroid_file )
122+ log .info (f"{ self .name } loaded centroids from { centroid_path } for IVF_FLAT index training." )
123+
124+ index_param ["centroids" ] = centroids .tolist ()
125+
126+ if self .is_vct :
127+ vct_path = self .case_config .index_param ().get ("vct_path" , None )
128+ log .debug (f"VCT: { vct_path } " )
129+ index_param ["virtual_cluster" ] = True
130+ index_kwargs ["tree_description" ] = vct_path
131+ self .is_vct = True
132+ log .info (f"{ self .name } VCT parameters set for IVF_FLAT index creation." )
133+
134+ def _adjust_batch_size (self ):
135+ self .batch_size = int (os .environ .get ("NUM_PER_BATCH" , "500000" ))
136+ log .debug (
137+ f"Set EnVector IVF_FLAT insert batch size to { self .batch_size } . "
138+ f"This should be the size of dataset for better performance when IVF_FLAT."
139+ )
137140
138141 @contextmanager
139142 def init (self ):
@@ -152,7 +155,7 @@ def init(self):
152155 try :
153156 self .col = es2 .Index (self .collection_name )
154157 if self .is_vct :
155- log .debug (f"VCT: { self .col .index_config .index_param .index_params [" virtual_cluster" ]} " )
158+ log .debug (f"VCT: { self .col .index_config .index_param .index_params [' virtual_cluster' ]} " )
156159 is_vct = self .case_config .index_param ().get ("is_vct" , False )
157160 assert self .is_vct == is_vct , "is_vct mismatch"
158161 vct_path = self .case_config .index_param ().get ("vct_path" , None )
@@ -190,7 +193,7 @@ def insert_embeddings(
190193 # use the first insert_embeddings to init collection
191194 assert self .col is not None
192195 assert len (embeddings ) == len (metadata )
193-
196+
194197 log .debug (f"IS_VCT: { self .is_vct } " )
195198
196199 insert_count = 0
@@ -229,7 +232,7 @@ def search_embedding(
229232 output_fields = ["metadata" ],
230233 search_params = self .case_config .search_param ().get ("search_params" , {}),
231234 )
232-
235+
233236 else :
234237 # Perform the search.
235238 res = self .col .search (
@@ -247,12 +250,11 @@ def search_embedding(
247250 # Extract metadata from results
248251 # res structure: [[{id: X, score: Y, metadata: Z}, ...]]
249252 log .debug (f"Search results: { res [0 ][:1 ]} " ) # Log first 1 results for debugging
250- if len (res ) > 0 and len (res [0 ]) > 0 :
251- return [int (result ["metadata" ]) for result in res [0 ] if "metadata" in result ]
252- else :
253+ if not (res and len (res [0 ]) > 0 ):
253254 log .warning (f"Unexpected result structure: { res } " )
254255 return []
256+ return [int (result ["metadata" ]) for result in res [0 ] if "metadata" in result ]
255257
256- except Exception as e :
257- log .error ( f "Search failed: { e } " )
258+ except Exception :
259+ log .exception ( "Search failed" )
258260 return []
0 commit comments