Skip to content

Commit 2e35a4d

Browse files
Add label filtering support to pgdiskann client
1 parent d035c62 commit 2e35a4d

1 file changed

Lines changed: 135 additions & 28 deletions

File tree

vectordb_bench/backend/clients/pgdiskann/pgdiskann.py

Lines changed: 135 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pgvector.psycopg import register_vector
1111
from psycopg import Connection, Cursor, sql
1212

13+
from vectordb_bench.backend.filter import Filter, FilterOp
1314
from ..api import VectorDB
1415
from .config import PgDiskANNConfigDict, PgDiskANNIndexConfig
1516

@@ -19,8 +20,14 @@
1920
class PgDiskANN(VectorDB):
2021
"""Use psycopg instructions"""
2122

23+
supported_filter_types: list[FilterOp] = [
24+
FilterOp.NonFilter,
25+
FilterOp.NumGE,
26+
FilterOp.StrEqual,
27+
]
28+
2229
conn: psycopg.Connection[Any] | None = None
23-
coursor: psycopg.Cursor[Any] | None = None
30+
cursor: psycopg.Cursor[Any] | None = None
2431

2532
_filtered_search: sql.Composed
2633
_unfiltered_search: sql.Composed
@@ -32,13 +39,18 @@ def __init__(
3239
db_case_config: PgDiskANNIndexConfig,
3340
collection_name: str = "pg_diskann_collection",
3441
drop_old: bool = False,
42+
with_scalar_labels: bool = False,
3543
**kwargs,
3644
):
3745
self.name = "PgDiskANN"
3846
self.db_config = db_config
3947
self.case_config = db_case_config
4048
self.table_name = collection_name
4149
self.dim = dim
50+
self.with_scalar_labels = with_scalar_labels
51+
self._scalar_label_field = "label"
52+
self.filter_op = None
53+
self.filter_value = None
4254

4355
self._index_name = "pgdiskann_index"
4456
self._primary_field = "id"
@@ -304,11 +316,23 @@ def _create_table(self, dim: int):
304316
try:
305317
log.info(f"{self.name} client create table : {self.table_name}")
306318

307-
self.cursor.execute(
308-
sql.SQL(
309-
"CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
310-
).format(table_name=sql.Identifier(self.table_name), dim=dim),
311-
)
319+
if self.with_scalar_labels:
320+
# Create table WITH label column
321+
self.cursor.execute(
322+
sql.SQL(
323+
"""
324+
CREATE TABLE IF NOT EXISTS public.{table_name}
325+
(id BIGINT PRIMARY KEY, embedding vector({dim}), label VARCHAR(64));
326+
""",
327+
).format(table_name=sql.Identifier(self.table_name), dim=dim),
328+
)
329+
else:
330+
# Create table WITHOUT label column (existing behavior)
331+
self.cursor.execute(
332+
sql.SQL(
333+
"CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
334+
).format(table_name=sql.Identifier(self.table_name), dim=dim),
335+
)
312336
self.conn.commit()
313337
except Exception as e:
314338
log.warning(f"Failed to create pgdiskann table: {self.table_name} error: {e}")
@@ -318,11 +342,15 @@ def insert_embeddings(
318342
self,
319343
embeddings: list[list[float]],
320344
metadata: list[int],
345+
labels_data: list[str] | None = None,
321346
**kwargs: Any,
322347
) -> tuple[int, Exception | None]:
323348
assert self.conn is not None, "Connection is not initialized"
324349
assert self.cursor is not None, "Cursor is not initialized"
325350

351+
if self.with_scalar_labels:
352+
assert labels_data is not None, "labels_data should be provided if with_scalar_labels is set to True"
353+
326354
try:
327355
metadata_arr = np.array(metadata)
328356
embeddings_arr = np.array(embeddings)
@@ -332,9 +360,14 @@ def insert_embeddings(
332360
table_name=sql.Identifier(self.table_name),
333361
),
334362
) as copy:
335-
copy.set_types(["bigint", "vector"])
336-
for i, row in enumerate(metadata_arr):
337-
copy.write_row((row, embeddings_arr[i]))
363+
if self.with_scalar_labels:
364+
copy.set_types(["bigint", "vector", "varchar"])
365+
for i, row in enumerate(metadata_arr):
366+
copy.write_row((row, embeddings_arr[i], labels_data[i]))
367+
else:
368+
copy.set_types(["bigint", "vector"])
369+
for i, row in enumerate(metadata_arr):
370+
copy.write_row((row, embeddings_arr[i]))
338371
self.conn.commit()
339372

340373
if kwargs.get("last_batch"):
@@ -345,49 +378,123 @@ def insert_embeddings(
345378
log.warning(f"Failed to insert data into table ({self.table_name}), error: {e}")
346379
return 0, e
347380

381+
def prepare_filter(self, filters: Filter) -> None:
    """Record the filter that subsequent searches should apply.

    The chosen operator and its comparison value are stored on the instance
    as ``self.filter_op`` and ``self.filter_value``; ``search_embedding``
    reads them to pick the statement it executes.

    Args:
        filters: Filter description. ``filters.type`` selects the branch:
            ``FilterOp.NonFilter`` stores no value, ``FilterOp.NumGE`` stores
            ``filters.int_value`` and ``FilterOp.StrEqual`` stores
            ``filters.label_value``.

    Raises:
        ValueError: If ``filters.type`` is none of the three operators above.
    """
    # FilterOp is imported at module level, so it is not re-imported here.
    # Resolve operator and value first, then assign both together, so a
    # failure while reading the value leaves the previous filter untouched.
    if filters.type == FilterOp.NonFilter:
        selected_op = FilterOp.NonFilter
        selected_value = None
    elif filters.type == FilterOp.NumGE:
        # Integer filtering: WHERE id >= X
        selected_op = FilterOp.NumGE
        selected_value = filters.int_value
    elif filters.type == FilterOp.StrEqual:
        # Label filtering: WHERE label = 'label_1p'
        selected_op = FilterOp.StrEqual
        selected_value = filters.label_value
    else:
        msg = f"Not support Filter for PgDiskANN - {filters}"
        raise ValueError(msg)

    self.filter_op = selected_op
    self.filter_value = selected_value
401+
402+
348403
def search_embedding(
    self,
    query: list[float],
    k: int = 100,
    filters: dict | None = None,
    timeout: int | None = None,
) -> list[int]:
    """Return the ids of up to ``k`` rows nearest to ``query``.

    The filter applied is the one stored on the instance in
    ``self.filter_op`` / ``self.filter_value`` (see ``prepare_filter``):

    * ``FilterOp.StrEqual`` runs a label-equality statement built per call.
    * ``FilterOp.NumGE`` runs ``self._filtered_search``.
    * anything else (including ``None``) runs ``self._unfiltered_search``.

    Args:
        query: Query vector.
        k: Value bound to the final ``LIMIT`` placeholder.
        filters: Accepted for interface compatibility; not read here.
        timeout: Accepted for interface compatibility; not read here.

    Returns:
        The first column of every result row, converted to ``int``.
    """
    assert self.conn is not None, "Connection is not initialized"
    assert self.cursor is not None, "Cursor is not initialized"

    search_params = self.case_config.search_param()
    is_reranking = search_params.get("reranking", False)
    q = np.asarray(query)

    # Pick the statement and the parameter(s) that precede the query vector.
    if self.filter_op == FilterOp.StrEqual:
        statement = self._build_label_search(search_params, is_reranking)
        leading_params = (self.filter_value,)
    elif self.filter_op == FilterOp.NumGE:
        statement = self._filtered_search
        leading_params = (self.filter_value,)
    else:
        statement = self._unfiltered_search
        leading_params = ()

    # With reranking enabled the query vector is passed twice: once for the
    # inner ordering and once for the reranking ordering.
    vector_params = (q, q) if is_reranking else (q,)

    result = self.cursor.execute(
        statement,
        (*leading_params, *vector_params, k),
        prepare=True,
        binary=True,
    )
    return [int(row[0]) for row in result.fetchall()]

def _build_label_search(self, search_params: dict, is_reranking: bool) -> sql.Composed:
    """Compose the statement that searches only rows with a matching label.

    Placeholders, in order: label value, query vector (twice when
    ``is_reranking`` is true), result limit.
    """
    table_name = sql.Identifier(self.table_name)
    label_field = sql.Identifier(self._scalar_label_field)

    if is_reranking:
        return sql.SQL(
            """
            SELECT i.id
            FROM (
                SELECT id, embedding
                FROM public.{table_name}
                WHERE {label_field} = %s
                ORDER BY embedding {metric_fun_op} %s::vector
                LIMIT {quantized_fetch_limit}::int
            ) i
            ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
            LIMIT %s::int
            """,
        ).format(
            table_name=table_name,
            label_field=label_field,
            metric_fun_op=sql.SQL(search_params["metric_fun_op"]),
            reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]),
            quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]),
        )

    return sql.Composed(
        [
            sql.SQL(
                "SELECT id FROM public.{table_name} WHERE {label_field} = %s ORDER BY embedding ",
            ).format(table_name=table_name, label_field=label_field),
            sql.SQL(search_params["metric_fun_op"]),
            sql.SQL(" %s::vector LIMIT %s::int"),
        ],
    )

0 commit comments

Comments
 (0)