1010from pgvector .psycopg import register_vector
1111from psycopg import Connection , Cursor , sql
1212
13+ from vectordb_bench .backend .filter import Filter , FilterOp
14+
1315from ..api import VectorDB
1416from .config import PgDiskANNConfigDict , PgDiskANNIndexConfig
1517
1921class PgDiskANN (VectorDB ):
2022 """Use psycopg instructions"""
2123
24+ supported_filter_types : list [FilterOp ] = [
25+ FilterOp .NonFilter ,
26+ FilterOp .NumGE ,
27+ FilterOp .StrEqual ,
28+ ]
29+
2230 conn : psycopg .Connection [Any ] | None = None
23- coursor : psycopg .Cursor [Any ] | None = None
31+ cursor : psycopg .Cursor [Any ] | None = None
2432
25- _filtered_search : sql .Composed
26- _unfiltered_search : sql .Composed
33+ _search : sql .Composed
2734
2835 def __init__ (
2936 self ,
@@ -32,13 +39,17 @@ def __init__(
3239 db_case_config : PgDiskANNIndexConfig ,
3340 collection_name : str = "pg_diskann_collection" ,
3441 drop_old : bool = False ,
42+ with_scalar_labels : bool = False ,
3543 ** kwargs ,
3644 ):
3745 self .name = "PgDiskANN"
3846 self .db_config = db_config
3947 self .case_config = db_case_config
4048 self .table_name = collection_name
4149 self .dim = dim
50+ self .with_scalar_labels = with_scalar_labels
51+ self ._scalar_label_field = "label"
52+ self .where_clause = ""
4253
4354 self ._index_name = "pgdiskann_index"
4455 self ._primary_field = "id"
@@ -86,83 +97,58 @@ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
8697
8798 return conn , cursor
8899
89- @contextmanager
90- def init (self ) -> Generator [None , None , None ]:
91- self .conn , self .cursor = self ._create_connection (** self .db_config )
92-
93- session_options : dict [str , Any ] = self .case_config .session_param ()
94-
95- if len (session_options ) > 0 :
96- for setting_name , setting_val in session_options .items ():
97- command = sql .SQL ("SET {setting_name} = {setting_val};" ).format (
98- setting_name = sql .Identifier (setting_name ), setting_val = sql .Literal (setting_val )
99- )
100- log .debug (command .as_string (self .cursor ))
101- self .cursor .execute (command )
102- self .conn .commit ()
103-
100+ def _generate_search_query (self ) -> sql .Composed :
101+ """Generate search query with where_clause placeholder"""
104102 search_params = self .case_config .search_param ()
105103
106104 if search_params .get ("reranking" ):
107- # Reranking-enabled queries
108- self ._filtered_search = sql .SQL ("""
105+ search_query = sql .SQL ("""
109106 SELECT i.id
110107 FROM (
111108 SELECT id, embedding
112109 FROM public.{table_name}
113- WHERE id >= %s
110+ {where_clause}
114111 ORDER BY embedding {metric_fun_op} %s::vector
115112 LIMIT {quantized_fetch_limit}::int
116113 ) i
117114 ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
118115 LIMIT %s::int
119- """ ).format (
116+ """ ).format (
120117 table_name = sql .Identifier (self .table_name ),
118+ where_clause = sql .SQL (self .where_clause ),
121119 metric_fun_op = sql .SQL (search_params ["metric_fun_op" ]),
122120 reranking_metric_fun_op = sql .SQL (search_params ["reranking_metric_fun_op" ]),
123121 quantized_fetch_limit = sql .Literal (search_params ["quantized_fetch_limit" ]),
124122 )
125-
126- self ._unfiltered_search = sql .SQL ("""
127- SELECT i.id
128- FROM (
129- SELECT id, embedding
130- FROM public.{table_name}
131- ORDER BY embedding {metric_fun_op} %s::vector
132- LIMIT {quantized_fetch_limit}::int
133- ) i
134- ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
135- LIMIT %s::int
136- """ ).format (
137- table_name = sql .Identifier (self .table_name ),
138- metric_fun_op = sql .SQL (search_params ["metric_fun_op" ]),
139- reranking_metric_fun_op = sql .SQL (search_params ["reranking_metric_fun_op" ]),
140- quantized_fetch_limit = sql .Literal (search_params ["quantized_fetch_limit" ]),
141- )
142-
143123 else :
144- self . _filtered_search = sql .Composed (
124+ search_query = sql .Composed (
145125 [
146- sql .SQL (
147- "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding " ,
148- ).format (table_name = sql .Identifier (self .table_name )),
149- sql .SQL (search_params ["metric_fun_op" ]),
150- sql .SQL (" %s::vector LIMIT %s::int" ),
151- ]
152- )
153-
154- self ._unfiltered_search = sql .Composed (
155- [
156- sql .SQL ("SELECT id FROM public.{table_name} ORDER BY embedding " ).format (
157- table_name = sql .Identifier (self .table_name )
126+ sql .SQL ("SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding " ).format (
127+ table_name = sql .Identifier (self .table_name ),
128+ where_clause = sql .SQL (self .where_clause ),
158129 ),
159130 sql .SQL (search_params ["metric_fun_op" ]),
160131 sql .SQL (" %s::vector LIMIT %s::int" ),
161132 ]
162133 )
163134
164- log .debug (f"Unfiltered search query={ self ._unfiltered_search .as_string (self .conn )} " )
165- log .debug (f"Filtered search query={ self ._filtered_search .as_string (self .conn )} " )
135+ return search_query
136+
137+ @contextmanager
138+ def init (self ) -> Generator [None , None , None ]:
139+ self .conn , self .cursor = self ._create_connection (** self .db_config )
140+
141+ session_options : dict [str , Any ] = self .case_config .session_param ()
142+
143+ if len (session_options ) > 0 :
144+ for setting_name , setting_val in session_options .items ():
145+ command = sql .SQL ("SET {setting_name} = {setting_val};" ).format (
146+ setting_name = sql .Identifier (setting_name ),
147+ setting_val = sql .Literal (setting_val ),
148+ )
149+ log .debug (command .as_string (self .cursor ))
150+ self .cursor .execute (command )
151+ self .conn .commit ()
166152
167153 try :
168154 yield
@@ -281,12 +267,10 @@ def _create_index(self):
281267
282268 with_clause = sql .SQL ("WITH ({});" ).format (sql .SQL (", " ).join (options )) if any (options ) else sql .Composed (())
283269
284- index_create_sql = sql .SQL (
285- """
270+ index_create_sql = sql .SQL ("""
286271 CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
287272 USING {index_type} (embedding {embedding_metric})
288- """ ,
289- ).format (
273+ """ ).format (
290274 index_name = sql .Identifier (self ._index_name ),
291275 table_name = sql .Identifier (self .table_name ),
292276 index_type = sql .Identifier (index_param ["index_type" ].lower ()),
@@ -304,11 +288,36 @@ def _create_table(self, dim: int):
304288 try :
305289 log .info (f"{ self .name } client create table : { self .table_name } " )
306290
291+ if self .with_scalar_labels :
292+ self .cursor .execute (
293+ sql .SQL ("""
294+ CREATE TABLE IF NOT EXISTS public.{table_name}
295+ ({primary_field} BIGINT PRIMARY KEY, embedding vector({dim}), {label_field} VARCHAR(64));
296+ """ ).format (
297+ table_name = sql .Identifier (self .table_name ),
298+ dim = dim ,
299+ primary_field = sql .Identifier (self ._primary_field ),
300+ label_field = sql .Identifier (self ._scalar_label_field ),
301+ ),
302+ )
303+ else :
304+ self .cursor .execute (
305+ sql .SQL ("""
306+ CREATE TABLE IF NOT EXISTS public.{table_name}
307+ ({primary_field} BIGINT PRIMARY KEY, embedding vector({dim}));
308+ """ ).format (
309+ table_name = sql .Identifier (self .table_name ),
310+ dim = dim ,
311+ primary_field = sql .Identifier (self ._primary_field ),
312+ ),
313+ )
314+
307315 self .cursor .execute (
308- sql .SQL (
309- "CREATE TABLE IF NOT EXISTS public.{ table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));" ,
310- ). format ( table_name = sql . Identifier ( self . table_name ), dim = dim ) ,
316+ sql .SQL ("ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;" ). format (
317+ table_name = sql . Identifier ( self . table_name )
318+ ),
311319 )
320+
312321 self .conn .commit ()
313322 except Exception as e :
314323 log .warning (f"Failed to create pgdiskann table: { self .table_name } error: { e } " )
@@ -318,11 +327,15 @@ def insert_embeddings(
318327 self ,
319328 embeddings : list [list [float ]],
320329 metadata : list [int ],
330+ labels_data : list [str ] | None = None ,
321331 ** kwargs : Any ,
322332 ) -> tuple [int , Exception | None ]:
323333 assert self .conn is not None , "Connection is not initialized"
324334 assert self .cursor is not None , "Cursor is not initialized"
325335
336+ if self .with_scalar_labels :
337+ assert labels_data is not None , "labels_data should be provided if with_scalar_labels is set to True"
338+
326339 try :
327340 metadata_arr = np .array (metadata )
328341 embeddings_arr = np .array (embeddings )
@@ -332,9 +345,14 @@ def insert_embeddings(
332345 table_name = sql .Identifier (self .table_name ),
333346 ),
334347 ) as copy :
335- copy .set_types (["bigint" , "vector" ])
336- for i , row in enumerate (metadata_arr ):
337- copy .write_row ((row , embeddings_arr [i ]))
348+ if self .with_scalar_labels :
349+ copy .set_types (["bigint" , "vector" , "varchar" ])
350+ for i , row in enumerate (metadata_arr ):
351+ copy .write_row ((row , embeddings_arr [i ], labels_data [i ]))
352+ else :
353+ copy .set_types (["bigint" , "vector" ])
354+ for i , row in enumerate (metadata_arr ):
355+ copy .write_row ((row , embeddings_arr [i ]))
338356 self .conn .commit ()
339357
340358 if kwargs .get ("last_batch" ):
@@ -345,49 +363,39 @@ def insert_embeddings(
345363 log .warning (f"Failed to insert data into table ({ self .table_name } ), error: { e } " )
346364 return 0 , e
347365
366+ def prepare_filter (self , filters : Filter ):
367+ """Prepare filter - builds where_clause"""
368+ if filters .type == FilterOp .NonFilter :
369+ self .where_clause = ""
370+ elif filters .type == FilterOp .NumGE :
371+ self .where_clause = f"WHERE { self ._primary_field } >= { filters .int_value } "
372+ elif filters .type == FilterOp .StrEqual :
373+ self .where_clause = f"WHERE { self ._scalar_label_field } = '{ filters .label_value } '"
374+ else :
375+ msg = f"Not support Filter for PgDiskANN - { filters } "
376+ raise ValueError (msg )
377+
378+ self ._search = self ._generate_search_query ()
379+ log .debug (f"Search query={ self ._search .as_string (self .conn )} " )
380+
348381 def search_embedding (
349382 self ,
350383 query : list [float ],
351384 k : int = 100 ,
352- filters : dict | None = None ,
353385 timeout : int | None = None ,
386+ ** kwargs : Any ,
354387 ) -> list [int ]:
355388 assert self .conn is not None , "Connection is not initialized"
356389 assert self .cursor is not None , "Cursor is not initialized"
357390
358391 search_params = self .case_config .search_param ()
359- is_reranking = search_params .get ("reranking" , False )
360-
361392 q = np .asarray (query )
362- if filters :
363- gt = filters .get ("id" )
364- if is_reranking :
365- result = self .cursor .execute (
366- self ._filtered_search ,
367- (gt , q , q , k ),
368- prepare = True ,
369- binary = True ,
370- )
371- else :
372- result = self .cursor .execute (
373- self ._filtered_search ,
374- (gt , q , k ),
375- prepare = True ,
376- binary = True ,
377- )
378- elif is_reranking :
379- result = self .cursor .execute (
380- self ._unfiltered_search ,
381- (q , q , k ),
382- prepare = True ,
383- binary = True ,
384- )
385- else :
386- result = self .cursor .execute (
387- self ._unfiltered_search ,
388- (q , k ),
389- prepare = True ,
390- binary = True ,
391- )
393+
394+ result = self .cursor .execute (
395+ self ._search ,
396+ (q , q , k ) if search_params .get ("reranking" , False ) else (q , k ),
397+ prepare = True ,
398+ binary = True ,
399+ )
392400
393401 return [int (i [0 ]) for i in result .fetchall ()]
0 commit comments