44import os
55from collections .abc import Iterable
66from contextlib import contextmanager
7+ from pathlib import Path
78from typing import Any
89
910import es2
@@ -54,7 +55,6 @@ def __init__(
5455
5556 self .is_vct : bool = False
5657 self .vct_params : dict [str , Any ] = {}
57- kwargs : dict [str , Any ] = {}
5858
5959 es2 .init (
6060 address = self .db_config .get ("uri" ),
@@ -70,66 +70,73 @@ def __init__(
7070 # Create the collection
7171 log .info (f"{ self .name } create index: { self .collection_name } " )
7272
73+ index_kwargs = dict (kwargs )
74+ self ._ensure_index (dim , index_kwargs )
75+
76+ es2 .disconnect ()
77+
def _ensure_index(self, dim: int, index_kwargs: dict[str, Any]):
    """Create the index unless one named ``self.collection_name`` already exists.

    Args:
        dim: Vector dimension forwarded to ``_create_index``.
        index_kwargs: Extra keyword arguments forwarded to ``_create_index``.
    """
    already_present = self.collection_name in es2.get_index_list()
    if not already_present:
        self._create_index(dim, index_kwargs)
        return

    # Existing index is reused as-is; only the VCT flag is refreshed from the case config.
    log.info(f"{self.name} index {self.collection_name} already exists, skip creating")
    self.is_vct = self.case_config.index_param().get("is_vct", False)
    log.debug(f"IS_VCT: {self.is_vct} ")
7785
78- else :
79- index_param = self .case_config .index_param ().get ("params" , {})
80- index_type = index_param .get ("index_type" , "FLAT" )
81- train_centroids = self .case_config .index_param ().get ("train_centroids" , False )
82-
83- if index_type == "IVF_FLAT" and train_centroids :
84-
85- centroid_path = self .case_config .index_param ().get ("centroids_path" , None )
86- self .is_vct = self .case_config .index_param ().get ("is_vct" , False )
87- log .debug (f"IS_VCT: { self .is_vct } " )
def _create_index(self, dim: int, index_kwargs: dict[str, Any]):
    """Create the index named ``self.collection_name`` from the case configuration.

    For an ``IVF_FLAT`` index, centroid settings are configured first when
    ``train_centroids`` is set, then the insert batch size is adjusted. The index
    is created afterwards in every case.

    Args:
        dim: Vector dimension passed to ``es2.create_index``.
        index_kwargs: Extra keyword arguments passed through to
            ``es2.create_index``; ``_configure_centroids`` may add entries to it.
    """
    index_param = self.case_config.index_param().get("params", {})
    is_ivf_flat = index_param.get("index_type", "FLAT") == "IVF_FLAT"
    wants_training = self.case_config.index_param().get("train_centroids", False)

    if is_ivf_flat:
        if wants_training:
            self._configure_centroids(index_param, index_kwargs)
        self._adjust_batch_size()

    creation_args = {
        "index_name": self.collection_name,
        "dim": dim,
        "key_path": self.db_config.get("key_path"),
        "key_id": self.db_config.get("key_id"),
        "index_params": index_param,
        "eval_mode": self.case_config.eval_mode,
    }
    es2.create_index(**creation_args, **index_kwargs)
120106
121- # create index after training centroids
122- es2 .create_index (
123- index_name = self .collection_name ,
124- dim = dim ,
125- key_path = self .db_config .get ("key_path" ),
126- key_id = self .db_config .get ("key_id" ),
127- index_params = index_param ,
128- eval_mode = self .case_config .eval_mode ,
129- ** kwargs ,
130- )
def _configure_centroids(self, index_param: dict[str, Any], index_kwargs: dict[str, Any]):
    """Load pre-trained centroids and merge them into the index settings.

    Both dictionaries are modified in place: ``index_param`` receives the
    ``"centroids"`` list (and ``"virtual_cluster"`` when VCT is enabled), and
    ``index_kwargs`` receives ``"tree_description"`` when VCT is enabled.
    ``self.is_vct`` is updated from the case configuration before validation.

    Args:
        index_param: Index parameters that will be passed to index creation.
        index_kwargs: Extra keyword arguments that will be passed to index creation.

    Raises:
        ValueError: If the case configuration has no ``centroids_path``.
        FileNotFoundError: If ``centroids_path`` is not an existing regular file.
    """
    centroid_path = self.case_config.index_param().get("centroids_path", None)
    self.is_vct = self.case_config.index_param().get("is_vct", False)
    log.debug(f"IS_VCT: {self.is_vct} ")

    if centroid_path is None:
        raise ValueError("Centroids path must be provided for IVF_FLAT index training.")

    centroid_file = Path(centroid_path)
    # is_file() rather than exists(): a directory would pass an existence check
    # and only fail later, less clearly, when np.load tries to open it.
    if not centroid_file.is_file():
        msg = f"Centroid file {centroid_path} not found for IVF_FLAT index training."
        raise FileNotFoundError(msg)

    log.debug(f"Centroids: {centroid_path} ")
    centroids = np.load(centroid_file)
    log.info(f"{self.name} loaded centroids from {centroid_path} for IVF_FLAT index training.")

    # Stored as a plain list rather than the loaded array object.
    index_param["centroids"] = centroids.tolist()

    if self.is_vct:
        vct_path = self.case_config.index_param().get("vct_path", None)
        log.debug(f"VCT: {vct_path} ")
        index_param["virtual_cluster"] = True
        index_kwargs["tree_description"] = vct_path
        # Stores a plain True even if the configured flag was merely truthy.
        self.is_vct = True
        log.info(f"{self.name} VCT parameters set for IVF_FLAT index creation.")
def _adjust_batch_size(self):
    """Set ``self.batch_size`` for IVF_FLAT inserts.

    The value comes from the ``NUM_PER_BATCH`` environment variable and falls
    back to 500000 when the variable is unset.
    """
    configured = os.environ.get("NUM_PER_BATCH", "500000")
    self.batch_size = int(configured)
    log.debug(
        f"Set EnVector IVF_FLAT insert batch size to {self.batch_size} . "
        f"This should be the size of dataset for better performance when IVF_FLAT."
    )
133140
134141 @contextmanager
135142 def init (self ):
@@ -148,7 +155,7 @@ def init(self):
148155 try :
149156 self .col = es2 .Index (self .collection_name )
150157 if self .is_vct :
151- log .debug (f"VCT: { self .col .index_config .index_param .index_params [" virtual_cluster" ]} " )
158+ log .debug (f"VCT: { self .col .index_config .index_param .index_params [' virtual_cluster' ]} " )
152159 is_vct = self .case_config .index_param ().get ("is_vct" , False )
153160 assert self .is_vct == is_vct , "is_vct mismatch"
154161 vct_path = self .case_config .index_param ().get ("vct_path" , None )
@@ -243,11 +250,11 @@ def search_embedding(
243250 # Extract metadata from results
244251 # res structure: [[{id: X, score: Y, metadata: Z}, ...]]
245252 log .debug (f"Search results: { res [0 ][:1 ]} " ) # Log first 1 results for debugging
246- if len (res ) > 0 and len (res [0 ]) > 0 :
247- return [ int ( result [ "metadata" ]) for result in res [ 0 ] if "metadata" in result ]
248- log . warning ( f"Unexpected result structure: { res } " )
249- return []
253+ if not (res and len (res [0 ]) > 0 ) :
254+ log . warning ( f"Unexpected result structure: { res } " )
255+ return []
256+ return [int ( result [ "metadata" ]) for result in res [ 0 ] if "metadata" in result ]
250257
251- except Exception as e :
252- log .error ( f "Search failed: { e } " )
258+ except Exception :
259+ log .exception ( "Search failed" )
253260 return []
0 commit comments