1010from pgvector .psycopg import register_vector
1111from psycopg import Connection , Cursor , sql
1212
13+ from vectordb_bench .backend .filter import Filter , FilterOp
1314from ..api import VectorDB
1415from .config import PgDiskANNConfigDict , PgDiskANNIndexConfig
1516
1920class PgDiskANN (VectorDB ):
2021 """Use psycopg instructions"""
2122
23+ supported_filter_types : list [FilterOp ] = [
24+ FilterOp .NonFilter ,
25+ FilterOp .NumGE ,
26+ FilterOp .StrEqual ,
27+ ]
28+
2229 conn : psycopg .Connection [Any ] | None = None
23- coursor : psycopg .Cursor [Any ] | None = None
30+ cursor : psycopg .Cursor [Any ] | None = None
2431
2532 _filtered_search : sql .Composed
2633 _unfiltered_search : sql .Composed
@@ -32,13 +39,18 @@ def __init__(
3239 db_case_config : PgDiskANNIndexConfig ,
3340 collection_name : str = "pg_diskann_collection" ,
3441 drop_old : bool = False ,
42+ with_scalar_labels : bool = False ,
3543 ** kwargs ,
3644 ):
3745 self .name = "PgDiskANN"
3846 self .db_config = db_config
3947 self .case_config = db_case_config
4048 self .table_name = collection_name
4149 self .dim = dim
50+ self .with_scalar_labels = with_scalar_labels
51+ self ._scalar_label_field = "label"
52+ self .filter_op = None
53+ self .filter_value = None
4254
4355 self ._index_name = "pgdiskann_index"
4456 self ._primary_field = "id"
@@ -304,11 +316,23 @@ def _create_table(self, dim: int):
304316 try :
305317 log .info (f"{ self .name } client create table : { self .table_name } " )
306318
307- self .cursor .execute (
308- sql .SQL (
309- "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));" ,
310- ).format (table_name = sql .Identifier (self .table_name ), dim = dim ),
311- )
319+ if self .with_scalar_labels :
320+ # Create table WITH label column
321+ self .cursor .execute (
322+ sql .SQL (
323+ """
324+ CREATE TABLE IF NOT EXISTS public.{table_name}
325+ (id BIGINT PRIMARY KEY, embedding vector({dim}), label VARCHAR(64));
326+ """ ,
327+ ).format (table_name = sql .Identifier (self .table_name ), dim = dim ),
328+ )
329+ else :
330+ # Create table WITHOUT label column (existing behavior)
331+ self .cursor .execute (
332+ sql .SQL (
333+ "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));" ,
334+ ).format (table_name = sql .Identifier (self .table_name ), dim = dim ),
335+ )
312336 self .conn .commit ()
313337 except Exception as e :
314338 log .warning (f"Failed to create pgdiskann table: { self .table_name } error: { e } " )
@@ -318,11 +342,15 @@ def insert_embeddings(
318342 self ,
319343 embeddings : list [list [float ]],
320344 metadata : list [int ],
345+ labels_data : list [str ] | None = None ,
321346 ** kwargs : Any ,
322347 ) -> tuple [int , Exception | None ]:
323348 assert self .conn is not None , "Connection is not initialized"
324349 assert self .cursor is not None , "Cursor is not initialized"
325350
351+ if self .with_scalar_labels :
352+ assert labels_data is not None , "labels_data should be provided if with_scalar_labels is set to True"
353+
326354 try :
327355 metadata_arr = np .array (metadata )
328356 embeddings_arr = np .array (embeddings )
@@ -332,9 +360,14 @@ def insert_embeddings(
332360 table_name = sql .Identifier (self .table_name ),
333361 ),
334362 ) as copy :
335- copy .set_types (["bigint" , "vector" ])
336- for i , row in enumerate (metadata_arr ):
337- copy .write_row ((row , embeddings_arr [i ]))
363+ if self .with_scalar_labels :
364+ copy .set_types (["bigint" , "vector" , "varchar" ])
365+ for i , row in enumerate (metadata_arr ):
366+ copy .write_row ((row , embeddings_arr [i ], labels_data [i ]))
367+ else :
368+ copy .set_types (["bigint" , "vector" ])
369+ for i , row in enumerate (metadata_arr ):
370+ copy .write_row ((row , embeddings_arr [i ]))
338371 self .conn .commit ()
339372
340373 if kwargs .get ("last_batch" ):
@@ -345,49 +378,123 @@ def insert_embeddings(
345378 log .warning (f"Failed to insert data into table ({ self .table_name } ), error: { e } " )
346379 return 0 , e
347380
381+ def prepare_filter (self , filters : Filter ):
382+ """Prepare filter for label or integer-based filtering"""
383+ from vectordb_bench .backend .filter import FilterOp
384+
385+ if filters .type == FilterOp .NonFilter :
386+ self .filter_op = FilterOp .NonFilter
387+ self .filter_value = None
388+
389+ elif filters .type == FilterOp .NumGE :
390+ # Integer filtering: WHERE id >= X
391+ self .filter_op = FilterOp .NumGE
392+ self .filter_value = filters .int_value
393+
394+ elif filters .type == FilterOp .StrEqual :
395+ # Label filtering: WHERE label = 'label_1p'
396+ self .filter_op = FilterOp .StrEqual
397+ self .filter_value = filters .label_value
398+ else :
399+ msg = f"Not support Filter for PgDiskANN - { filters } "
400+ raise ValueError (msg )
401+
402+
348403 def search_embedding (
349404 self ,
350405 query : list [float ],
351406 k : int = 100 ,
352- filters : dict | None = None ,
407+ filters : dict | None = None ,
353408 timeout : int | None = None ,
354409 ) -> list [int ]:
355410 assert self .conn is not None , "Connection is not initialized"
356411 assert self .cursor is not None , "Cursor is not initialized"
357412
413+ from vectordb_bench .backend .filter import FilterOp
414+
358415 search_params = self .case_config .search_param ()
359416 is_reranking = search_params .get ("reranking" , False )
360-
417+
361418 q = np .asarray (query )
362- if filters :
363- gt = filters .get ("id" )
419+
420+ # Build the appropriate query based on filter_op
421+ if self .filter_op == FilterOp .StrEqual :
422+ # Label filtering: e.g. WHERE label = 'label_1p'
423+ if is_reranking :
424+ query_sql = sql .SQL ("""
425+ SELECT i.id
426+ FROM (
427+ SELECT id, embedding
428+ FROM public.{table_name}
429+ WHERE {label_field} = %s
430+ ORDER BY embedding {metric_fun_op} %s::vector
431+ LIMIT {quantized_fetch_limit}::int
432+ ) i
433+ ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
434+ LIMIT %s::int
435+ """ ).format (
436+ table_name = sql .Identifier (self .table_name ),
437+ label_field = sql .Identifier (self ._scalar_label_field ),
438+ metric_fun_op = sql .SQL (search_params ["metric_fun_op" ]),
439+ reranking_metric_fun_op = sql .SQL (search_params ["reranking_metric_fun_op" ]),
440+ quantized_fetch_limit = sql .Literal (search_params ["quantized_fetch_limit" ]),
441+ )
442+ result = self .cursor .execute (
443+ query_sql ,
444+ (self .filter_value , q , q , k ),
445+ prepare = True ,
446+ binary = True ,
447+ )
448+ else :
449+ query_sql = sql .Composed ([
450+ sql .SQL (
451+ "SELECT id FROM public.{table_name} WHERE {label_field} = %s ORDER BY embedding " ,
452+ ).format (
453+ table_name = sql .Identifier (self .table_name ),
454+ label_field = sql .Identifier (self ._scalar_label_field ),
455+ ),
456+ sql .SQL (search_params ["metric_fun_op" ]),
457+ sql .SQL (" %s::vector LIMIT %s::int" ),
458+ ])
459+ result = self .cursor .execute (
460+ query_sql ,
461+ (self .filter_value , q , k ),
462+ prepare = True ,
463+ binary = True ,
464+ )
465+
466+ elif self .filter_op == FilterOp .NumGE :
467+ # Integer filtering: WHERE id >= X (existing behavior)
364468 if is_reranking :
365469 result = self .cursor .execute (
366470 self ._filtered_search ,
367- (gt , q , q , k ),
471+ (self . filter_value , q , q , k ),
368472 prepare = True ,
369473 binary = True ,
370474 )
371475 else :
372476 result = self .cursor .execute (
373477 self ._filtered_search ,
374- (gt , q , k ),
478+ (self . filter_value , q , k ),
375479 prepare = True ,
376480 binary = True ,
377481 )
378- elif is_reranking :
379- result = self .cursor .execute (
380- self ._unfiltered_search ,
381- (q , q , k ),
382- prepare = True ,
383- binary = True ,
384- )
482+
385483 else :
386- result = self .cursor .execute (
387- self ._unfiltered_search ,
388- (q , k ),
389- prepare = True ,
390- binary = True ,
391- )
484+ # No filtering (existing behavior)
485+ if is_reranking :
486+ result = self .cursor .execute (
487+ self ._unfiltered_search ,
488+ (q , q , k ),
489+ prepare = True ,
490+ binary = True ,
491+ )
492+ else :
493+ result = self .cursor .execute (
494+ self ._unfiltered_search ,
495+ (q , k ),
496+ prepare = True ,
497+ binary = True ,
498+ )
392499
393500 return [int (i [0 ]) for i in result .fetchall ()]
0 commit comments