@@ -90,38 +90,83 @@ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
9090 def init (self ) -> Generator [None , None , None ]:
9191 self .conn , self .cursor = self ._create_connection (** self .db_config )
9292
93- # index configuration may have commands defined that we should set during each client session
9493 session_options : dict [str , Any ] = self .case_config .session_param ()
9594
9695 if len (session_options ) > 0 :
9796 for setting_name , setting_val in session_options .items ():
98- command = sql .SQL ("SET {setting_name} " + "= {setting_val};" ).format (
99- setting_name = sql .Identifier (setting_name ),
100- setting_val = sql .Identifier (str (setting_val )),
97+ command = sql .SQL ("SET {setting_name} = {setting_val};" ).format (
98+ setting_name = sql .Identifier (setting_name ), setting_val = sql .Literal (setting_val )
10199 )
102100 log .debug (command .as_string (self .cursor ))
103101 self .cursor .execute (command )
104102 self .conn .commit ()
105103
106- self ._filtered_search = sql .Composed (
107- [
108- sql .SQL (
109- "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding " ,
110- ).format (table_name = sql .Identifier (self .table_name )),
111- sql .SQL (self .case_config .search_param ()["metric_fun_op" ]),
112- sql .SQL (" %s::vector LIMIT %s::int" ),
113- ],
114- )
104+ search_params = self .case_config .search_param ()
105+
106+ if search_params .get ("reranking" ):
107+ # Reranking-enabled queries
108+ self ._filtered_search = sql .SQL (
109+ """
110+ SELECT i.id
111+ FROM (
112+ SELECT id, embedding
113+ FROM public.{table_name}
114+ WHERE id >= %s
115+ ORDER BY embedding {metric_fun_op} %s::vector
116+ LIMIT {quantized_fetch_limit}::int
117+ ) i
118+ ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
119+ LIMIT %s::int
120+ """
121+ ).format (
122+ table_name = sql .Identifier (self .table_name ),
123+ metric_fun_op = sql .SQL (search_params ["metric_fun_op" ]),
124+ reranking_metric_fun_op = sql .SQL (search_params ["reranking_metric_fun_op" ]),
125+ quantized_fetch_limit = sql .Literal (search_params ["quantized_fetch_limit" ]),
126+ )
115127
116- self ._unfiltered_search = sql .Composed (
117- [
118- sql .SQL ("SELECT id FROM public.{} ORDER BY embedding " ).format (
119- sql .Identifier (self .table_name ),
120- ),
121- sql .SQL (self .case_config .search_param ()["metric_fun_op" ]),
122- sql .SQL (" %s::vector LIMIT %s::int" ),
123- ],
124- )
128+ self ._unfiltered_search = sql .SQL (
129+ """
130+ SELECT i.id
131+ FROM (
132+ SELECT id, embedding
133+ FROM public.{table_name}
134+ ORDER BY embedding {metric_fun_op} %s::vector
135+ LIMIT {quantized_fetch_limit}::int
136+ ) i
137+ ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
138+ LIMIT %s::int
139+ """
140+ ).format (
141+ table_name = sql .Identifier (self .table_name ),
142+ metric_fun_op = sql .SQL (search_params ["metric_fun_op" ]),
143+ reranking_metric_fun_op = sql .SQL (search_params ["reranking_metric_fun_op" ]),
144+ quantized_fetch_limit = sql .Literal (search_params ["quantized_fetch_limit" ]),
145+ )
146+
147+ else :
148+ self ._filtered_search = sql .Composed (
149+ [
150+ sql .SQL (
151+ "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding " ,
152+ ).format (table_name = sql .Identifier (self .table_name )),
153+ sql .SQL (search_params ["metric_fun_op" ]),
154+ sql .SQL (" %s::vector LIMIT %s::int" ),
155+ ]
156+ )
157+
158+ self ._unfiltered_search = sql .Composed (
159+ [
160+ sql .SQL ("SELECT id FROM public.{table_name} ORDER BY embedding " ).format (
161+ table_name = sql .Identifier (self .table_name )
162+ ),
163+ sql .SQL (search_params ["metric_fun_op" ]),
164+ sql .SQL (" %s::vector LIMIT %s::int" ),
165+ ]
166+ )
167+
168+ log .debug (f"Unfiltered search query={ self ._unfiltered_search .as_string (self .conn )} " )
169+ log .debug (f"Filtered search query={ self ._filtered_search .as_string (self .conn )} " )
125170
126171 try :
127172 yield
@@ -234,7 +279,7 @@ def _create_index(self):
234279 options .append (
235280 sql .SQL ("{option_name} = {val}" ).format (
236281 option_name = sql .Identifier (option_name ),
237- val = sql .Identifier ( str ( option_val ) ),
282+ val = sql .Literal ( option_val ),
238283 ),
239284 )
240285
@@ -314,16 +359,39 @@ def search_embedding(
314359 assert self .conn is not None , "Connection is not initialized"
315360 assert self .cursor is not None , "Cursor is not initialized"
316361
362+ search_params = self .case_config .search_param ()
363+ is_reranking = search_params .get ("reranking" , False )
364+
317365 q = np .asarray (query )
318366 if filters :
319367 gt = filters .get ("id" )
368+ if is_reranking :
369+ result = self .cursor .execute (
370+ self ._filtered_search ,
371+ (gt , q , q , k ),
372+ prepare = True ,
373+ binary = True ,
374+ )
375+ else :
376+ result = self .cursor .execute (
377+ self ._filtered_search ,
378+ (gt , q , k ),
379+ prepare = True ,
380+ binary = True ,
381+ )
382+ elif is_reranking :
320383 result = self .cursor .execute (
321- self ._filtered_search ,
322- (gt , q , k ),
384+ self ._unfiltered_search ,
385+ (q , q , k ),
323386 prepare = True ,
324387 binary = True ,
325388 )
326389 else :
327- result = self .cursor .execute (self ._unfiltered_search , (q , k ), prepare = True , binary = True )
390+ result = self .cursor .execute (
391+ self ._unfiltered_search ,
392+ (q , k ),
393+ prepare = True ,
394+ binary = True ,
395+ )
328396
329397 return [int (i [0 ]) for i in result .fetchall ()]
0 commit comments