Skip to content

Commit 63cc50a

Browse files
Feat: Add label filter support in pgdiskann client (#724)
* Add label filtering support to pgdiskann client
* Refactor pgdiskann filtering logic
* Refactor: remove unrelated function
* style: apply black formatting to pgdiskann.py
* fix: remove trailing whitespace and fix import sorting
* docs: add comments for label naming and vector storage optimization
* Revert "docs: add comments for label naming and vector storage optimization" (reverts commit d10b296)

---------

Co-authored-by: Eesha Faisal <eesha.faisal@emumba.com>
1 parent 02e5d33 commit 63cc50a

1 file changed

Lines changed: 106 additions & 98 deletions

File tree

vectordb_bench/backend/clients/pgdiskann/pgdiskann.py

Lines changed: 106 additions & 98 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,8 @@
1010
from pgvector.psycopg import register_vector
1111
from psycopg import Connection, Cursor, sql
1212

13+
from vectordb_bench.backend.filter import Filter, FilterOp
14+
1315
from ..api import VectorDB
1416
from .config import PgDiskANNConfigDict, PgDiskANNIndexConfig
1517

@@ -19,11 +21,16 @@
1921
class PgDiskANN(VectorDB):
2022
"""Use psycopg instructions"""
2123

24+
supported_filter_types: list[FilterOp] = [
25+
FilterOp.NonFilter,
26+
FilterOp.NumGE,
27+
FilterOp.StrEqual,
28+
]
29+
2230
conn: psycopg.Connection[Any] | None = None
23-
coursor: psycopg.Cursor[Any] | None = None
31+
cursor: psycopg.Cursor[Any] | None = None
2432

25-
_filtered_search: sql.Composed
26-
_unfiltered_search: sql.Composed
33+
_search: sql.Composed
2734

2835
def __init__(
2936
self,
@@ -32,13 +39,17 @@ def __init__(
3239
db_case_config: PgDiskANNIndexConfig,
3340
collection_name: str = "pg_diskann_collection",
3441
drop_old: bool = False,
42+
with_scalar_labels: bool = False,
3543
**kwargs,
3644
):
3745
self.name = "PgDiskANN"
3846
self.db_config = db_config
3947
self.case_config = db_case_config
4048
self.table_name = collection_name
4149
self.dim = dim
50+
self.with_scalar_labels = with_scalar_labels
51+
self._scalar_label_field = "label"
52+
self.where_clause = ""
4253

4354
self._index_name = "pgdiskann_index"
4455
self._primary_field = "id"
@@ -86,83 +97,58 @@ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
8697

8798
return conn, cursor
8899

89-
@contextmanager
90-
def init(self) -> Generator[None, None, None]:
91-
self.conn, self.cursor = self._create_connection(**self.db_config)
92-
93-
session_options: dict[str, Any] = self.case_config.session_param()
94-
95-
if len(session_options) > 0:
96-
for setting_name, setting_val in session_options.items():
97-
command = sql.SQL("SET {setting_name} = {setting_val};").format(
98-
setting_name=sql.Identifier(setting_name), setting_val=sql.Literal(setting_val)
99-
)
100-
log.debug(command.as_string(self.cursor))
101-
self.cursor.execute(command)
102-
self.conn.commit()
103-
100+
def _generate_search_query(self) -> sql.Composed:
101+
"""Generate search query with where_clause placeholder"""
104102
search_params = self.case_config.search_param()
105103

106104
if search_params.get("reranking"):
107-
# Reranking-enabled queries
108-
self._filtered_search = sql.SQL("""
105+
search_query = sql.SQL("""
109106
SELECT i.id
110107
FROM (
111108
SELECT id, embedding
112109
FROM public.{table_name}
113-
WHERE id >= %s
110+
{where_clause}
114111
ORDER BY embedding {metric_fun_op} %s::vector
115112
LIMIT {quantized_fetch_limit}::int
116113
) i
117114
ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
118115
LIMIT %s::int
119-
""").format(
116+
""").format(
120117
table_name=sql.Identifier(self.table_name),
118+
where_clause=sql.SQL(self.where_clause),
121119
metric_fun_op=sql.SQL(search_params["metric_fun_op"]),
122120
reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]),
123121
quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]),
124122
)
125-
126-
self._unfiltered_search = sql.SQL("""
127-
SELECT i.id
128-
FROM (
129-
SELECT id, embedding
130-
FROM public.{table_name}
131-
ORDER BY embedding {metric_fun_op} %s::vector
132-
LIMIT {quantized_fetch_limit}::int
133-
) i
134-
ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
135-
LIMIT %s::int
136-
""").format(
137-
table_name=sql.Identifier(self.table_name),
138-
metric_fun_op=sql.SQL(search_params["metric_fun_op"]),
139-
reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]),
140-
quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]),
141-
)
142-
143123
else:
144-
self._filtered_search = sql.Composed(
124+
search_query = sql.Composed(
145125
[
146-
sql.SQL(
147-
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ",
148-
).format(table_name=sql.Identifier(self.table_name)),
149-
sql.SQL(search_params["metric_fun_op"]),
150-
sql.SQL(" %s::vector LIMIT %s::int"),
151-
]
152-
)
153-
154-
self._unfiltered_search = sql.Composed(
155-
[
156-
sql.SQL("SELECT id FROM public.{table_name} ORDER BY embedding ").format(
157-
table_name=sql.Identifier(self.table_name)
126+
sql.SQL("SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding ").format(
127+
table_name=sql.Identifier(self.table_name),
128+
where_clause=sql.SQL(self.where_clause),
158129
),
159130
sql.SQL(search_params["metric_fun_op"]),
160131
sql.SQL(" %s::vector LIMIT %s::int"),
161132
]
162133
)
163134

164-
log.debug(f"Unfiltered search query={self._unfiltered_search.as_string(self.conn)}")
165-
log.debug(f"Filtered search query={self._filtered_search.as_string(self.conn)}")
135+
return search_query
136+
137+
@contextmanager
138+
def init(self) -> Generator[None, None, None]:
139+
self.conn, self.cursor = self._create_connection(**self.db_config)
140+
141+
session_options: dict[str, Any] = self.case_config.session_param()
142+
143+
if len(session_options) > 0:
144+
for setting_name, setting_val in session_options.items():
145+
command = sql.SQL("SET {setting_name} = {setting_val};").format(
146+
setting_name=sql.Identifier(setting_name),
147+
setting_val=sql.Literal(setting_val),
148+
)
149+
log.debug(command.as_string(self.cursor))
150+
self.cursor.execute(command)
151+
self.conn.commit()
166152

167153
try:
168154
yield
@@ -281,12 +267,10 @@ def _create_index(self):
281267

282268
with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
283269

284-
index_create_sql = sql.SQL(
285-
"""
270+
index_create_sql = sql.SQL("""
286271
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
287272
USING {index_type} (embedding {embedding_metric})
288-
""",
289-
).format(
273+
""").format(
290274
index_name=sql.Identifier(self._index_name),
291275
table_name=sql.Identifier(self.table_name),
292276
index_type=sql.Identifier(index_param["index_type"].lower()),
@@ -304,11 +288,36 @@ def _create_table(self, dim: int):
304288
try:
305289
log.info(f"{self.name} client create table : {self.table_name}")
306290

291+
if self.with_scalar_labels:
292+
self.cursor.execute(
293+
sql.SQL("""
294+
CREATE TABLE IF NOT EXISTS public.{table_name}
295+
({primary_field} BIGINT PRIMARY KEY, embedding vector({dim}), {label_field} VARCHAR(64));
296+
""").format(
297+
table_name=sql.Identifier(self.table_name),
298+
dim=dim,
299+
primary_field=sql.Identifier(self._primary_field),
300+
label_field=sql.Identifier(self._scalar_label_field),
301+
),
302+
)
303+
else:
304+
self.cursor.execute(
305+
sql.SQL("""
306+
CREATE TABLE IF NOT EXISTS public.{table_name}
307+
({primary_field} BIGINT PRIMARY KEY, embedding vector({dim}));
308+
""").format(
309+
table_name=sql.Identifier(self.table_name),
310+
dim=dim,
311+
primary_field=sql.Identifier(self._primary_field),
312+
),
313+
)
314+
307315
self.cursor.execute(
308-
sql.SQL(
309-
"CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
310-
).format(table_name=sql.Identifier(self.table_name), dim=dim),
316+
sql.SQL("ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;").format(
317+
table_name=sql.Identifier(self.table_name)
318+
),
311319
)
320+
312321
self.conn.commit()
313322
except Exception as e:
314323
log.warning(f"Failed to create pgdiskann table: {self.table_name} error: {e}")
@@ -318,11 +327,15 @@ def insert_embeddings(
318327
self,
319328
embeddings: list[list[float]],
320329
metadata: list[int],
330+
labels_data: list[str] | None = None,
321331
**kwargs: Any,
322332
) -> tuple[int, Exception | None]:
323333
assert self.conn is not None, "Connection is not initialized"
324334
assert self.cursor is not None, "Cursor is not initialized"
325335

336+
if self.with_scalar_labels:
337+
assert labels_data is not None, "labels_data should be provided if with_scalar_labels is set to True"
338+
326339
try:
327340
metadata_arr = np.array(metadata)
328341
embeddings_arr = np.array(embeddings)
@@ -332,9 +345,14 @@ def insert_embeddings(
332345
table_name=sql.Identifier(self.table_name),
333346
),
334347
) as copy:
335-
copy.set_types(["bigint", "vector"])
336-
for i, row in enumerate(metadata_arr):
337-
copy.write_row((row, embeddings_arr[i]))
348+
if self.with_scalar_labels:
349+
copy.set_types(["bigint", "vector", "varchar"])
350+
for i, row in enumerate(metadata_arr):
351+
copy.write_row((row, embeddings_arr[i], labels_data[i]))
352+
else:
353+
copy.set_types(["bigint", "vector"])
354+
for i, row in enumerate(metadata_arr):
355+
copy.write_row((row, embeddings_arr[i]))
338356
self.conn.commit()
339357

340358
if kwargs.get("last_batch"):
@@ -345,49 +363,39 @@ def insert_embeddings(
345363
log.warning(f"Failed to insert data into table ({self.table_name}), error: {e}")
346364
return 0, e
347365

366+
def prepare_filter(self, filters: Filter):
367+
"""Prepare filter - builds where_clause"""
368+
if filters.type == FilterOp.NonFilter:
369+
self.where_clause = ""
370+
elif filters.type == FilterOp.NumGE:
371+
self.where_clause = f"WHERE {self._primary_field} >= {filters.int_value}"
372+
elif filters.type == FilterOp.StrEqual:
373+
self.where_clause = f"WHERE {self._scalar_label_field} = '{filters.label_value}'"
374+
else:
375+
msg = f"Not support Filter for PgDiskANN - {filters}"
376+
raise ValueError(msg)
377+
378+
self._search = self._generate_search_query()
379+
log.debug(f"Search query={self._search.as_string(self.conn)}")
380+
348381
def search_embedding(
349382
self,
350383
query: list[float],
351384
k: int = 100,
352-
filters: dict | None = None,
353385
timeout: int | None = None,
386+
**kwargs: Any,
354387
) -> list[int]:
355388
assert self.conn is not None, "Connection is not initialized"
356389
assert self.cursor is not None, "Cursor is not initialized"
357390

358391
search_params = self.case_config.search_param()
359-
is_reranking = search_params.get("reranking", False)
360-
361392
q = np.asarray(query)
362-
if filters:
363-
gt = filters.get("id")
364-
if is_reranking:
365-
result = self.cursor.execute(
366-
self._filtered_search,
367-
(gt, q, q, k),
368-
prepare=True,
369-
binary=True,
370-
)
371-
else:
372-
result = self.cursor.execute(
373-
self._filtered_search,
374-
(gt, q, k),
375-
prepare=True,
376-
binary=True,
377-
)
378-
elif is_reranking:
379-
result = self.cursor.execute(
380-
self._unfiltered_search,
381-
(q, q, k),
382-
prepare=True,
383-
binary=True,
384-
)
385-
else:
386-
result = self.cursor.execute(
387-
self._unfiltered_search,
388-
(q, k),
389-
prepare=True,
390-
binary=True,
391-
)
393+
394+
result = self.cursor.execute(
395+
self._search,
396+
(q, q, k) if search_params.get("reranking", False) else (q, k),
397+
prepare=True,
398+
binary=True,
399+
)
392400

393401
return [int(i[0]) for i in result.fetchall()]

0 commit comments

Comments (0)