Skip to content

Commit f253a41

Browse files
authored
Added support for Product Quantization in pg_diskann (zilliztech#579)
* Added support for PQ in pgdiskann Update list of allowed reranking metric Update list of allowed reranking metric * Run formatter
1 parent 1a4023a commit f253a41

5 files changed

Lines changed: 240 additions & 29 deletions

File tree

vectordb_bench/backend/clients/pgdiskann/cli.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pydantic import SecretStr
66

77
from vectordb_bench.backend.clients import DB
8+
from vectordb_bench.backend.clients.api import MetricType
89

910
from ....cli.cli import (
1011
CommonTypedDict,
@@ -48,6 +49,15 @@ class PgDiskAnnTypedDict(CommonTypedDict):
4849
help="PgDiskAnn l_value_ib",
4950
),
5051
]
52+
pq_param_num_chunks: Annotated[
53+
int,
54+
click.option(
55+
"--pq-param-num-chunks",
56+
type=int,
57+
help="PgDiskAnn pq_param_num_chunks",
58+
required=False,
59+
),
60+
]
5161
l_value_is: Annotated[
5262
float,
5363
click.option(
@@ -56,6 +66,37 @@ class PgDiskAnnTypedDict(CommonTypedDict):
5666
help="PgDiskAnn l_value_is",
5767
),
5868
]
69+
reranking: Annotated[
70+
bool | None,
71+
click.option(
72+
"--reranking/--skip-reranking",
73+
type=bool,
74+
help="Enable reranking for PQ search",
75+
default=False,
76+
),
77+
]
78+
reranking_metric: Annotated[
79+
str | None,
80+
click.option(
81+
"--reranking-metric",
82+
type=click.Choice(
83+
[metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD", "DP"]],
84+
),
85+
help="Distance metric for reranking",
86+
default="COSINE",
87+
show_default=True,
88+
required=False,
89+
),
90+
]
91+
quantized_fetch_limit: Annotated[
92+
int | None,
93+
click.option(
94+
"--quantized-fetch-limit",
95+
type=int,
96+
help="Limit of inner query in case of reranking",
97+
required=False,
98+
),
99+
]
59100
maintenance_work_mem: Annotated[
60101
str | None,
61102
click.option(
@@ -98,7 +139,11 @@ def PgDiskAnn(
98139
db_case_config=PgDiskANNImplConfig(
99140
max_neighbors=parameters["max_neighbors"],
100141
l_value_ib=parameters["l_value_ib"],
142+
pq_param_num_chunks=parameters["pq_param_num_chunks"],
101143
l_value_is=parameters["l_value_is"],
144+
reranking=parameters["reranking"],
145+
reranking_metric=parameters["reranking_metric"],
146+
quantized_fetch_limit=parameters["quantized_fetch_limit"],
102147
max_parallel_workers=parameters["max_parallel_workers"],
103148
maintenance_work_mem=parameters["maintenance_work_mem"],
104149
),

vectordb_bench/backend/clients/pgdiskann/config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ def parse_metric_fun_op(self) -> LiteralString:
6060
return "<#>"
6161
return "<=>"
6262

63+
def parse_reranking_metric_fun_op(self) -> LiteralString:
64+
if self.reranking_metric == MetricType.L2:
65+
return "<->"
66+
if self.reranking_metric == MetricType.IP:
67+
return "<#>"
68+
return "<=>"
69+
6370
def parse_metric_fun_str(self) -> str:
6471
if self.metric_type == MetricType.L2:
6572
return "l2_distance"
@@ -115,7 +122,11 @@ class PgDiskANNImplConfig(PgDiskANNIndexConfig):
115122
index: IndexType = IndexType.DISKANN
116123
max_neighbors: int | None
117124
l_value_ib: int | None
125+
pq_param_num_chunks: int | None
118126
l_value_is: float | None
127+
reranking: bool | None = None
128+
reranking_metric: str | None = None
129+
quantized_fetch_limit: int | None = None
119130
maintenance_work_mem: str | None = None
120131
max_parallel_workers: int | None = None
121132

@@ -126,6 +137,8 @@ def index_param(self) -> dict:
126137
"options": {
127138
"max_neighbors": self.max_neighbors,
128139
"l_value_ib": self.l_value_ib,
140+
"pq_param_num_chunks": self.pq_param_num_chunks,
141+
"product_quantized": str(self.reranking),
129142
},
130143
"maintenance_work_mem": self.maintenance_work_mem,
131144
"max_parallel_workers": self.max_parallel_workers,
@@ -135,6 +148,9 @@ def search_param(self) -> dict:
135148
return {
136149
"metric": self.parse_metric(),
137150
"metric_fun_op": self.parse_metric_fun_op(),
151+
"reranking": self.reranking,
152+
"reranking_metric_fun_op": self.parse_reranking_metric_fun_op(),
153+
"quantized_fetch_limit": self.quantized_fetch_limit,
138154
}
139155

140156
def session_param(self) -> dict:

vectordb_bench/backend/clients/pgdiskann/pgdiskann.py

Lines changed: 94 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -90,38 +90,83 @@ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
9090
def init(self) -> Generator[None, None, None]:
9191
self.conn, self.cursor = self._create_connection(**self.db_config)
9292

93-
# index configuration may have commands defined that we should set during each client session
9493
session_options: dict[str, Any] = self.case_config.session_param()
9594

9695
if len(session_options) > 0:
9796
for setting_name, setting_val in session_options.items():
98-
command = sql.SQL("SET {setting_name} " + "= {setting_val};").format(
99-
setting_name=sql.Identifier(setting_name),
100-
setting_val=sql.Identifier(str(setting_val)),
97+
command = sql.SQL("SET {setting_name} = {setting_val};").format(
98+
setting_name=sql.Identifier(setting_name), setting_val=sql.Literal(setting_val)
10199
)
102100
log.debug(command.as_string(self.cursor))
103101
self.cursor.execute(command)
104102
self.conn.commit()
105103

106-
self._filtered_search = sql.Composed(
107-
[
108-
sql.SQL(
109-
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ",
110-
).format(table_name=sql.Identifier(self.table_name)),
111-
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
112-
sql.SQL(" %s::vector LIMIT %s::int"),
113-
],
114-
)
104+
search_params = self.case_config.search_param()
105+
106+
if search_params.get("reranking"):
107+
# Reranking-enabled queries
108+
self._filtered_search = sql.SQL(
109+
"""
110+
SELECT i.id
111+
FROM (
112+
SELECT id, embedding
113+
FROM public.{table_name}
114+
WHERE id >= %s
115+
ORDER BY embedding {metric_fun_op} %s::vector
116+
LIMIT {quantized_fetch_limit}::int
117+
) i
118+
ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
119+
LIMIT %s::int
120+
"""
121+
).format(
122+
table_name=sql.Identifier(self.table_name),
123+
metric_fun_op=sql.SQL(search_params["metric_fun_op"]),
124+
reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]),
125+
quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]),
126+
)
115127

116-
self._unfiltered_search = sql.Composed(
117-
[
118-
sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
119-
sql.Identifier(self.table_name),
120-
),
121-
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
122-
sql.SQL(" %s::vector LIMIT %s::int"),
123-
],
124-
)
128+
self._unfiltered_search = sql.SQL(
129+
"""
130+
SELECT i.id
131+
FROM (
132+
SELECT id, embedding
133+
FROM public.{table_name}
134+
ORDER BY embedding {metric_fun_op} %s::vector
135+
LIMIT {quantized_fetch_limit}::int
136+
) i
137+
ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
138+
LIMIT %s::int
139+
"""
140+
).format(
141+
table_name=sql.Identifier(self.table_name),
142+
metric_fun_op=sql.SQL(search_params["metric_fun_op"]),
143+
reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]),
144+
quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]),
145+
)
146+
147+
else:
148+
self._filtered_search = sql.Composed(
149+
[
150+
sql.SQL(
151+
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ",
152+
).format(table_name=sql.Identifier(self.table_name)),
153+
sql.SQL(search_params["metric_fun_op"]),
154+
sql.SQL(" %s::vector LIMIT %s::int"),
155+
]
156+
)
157+
158+
self._unfiltered_search = sql.Composed(
159+
[
160+
sql.SQL("SELECT id FROM public.{table_name} ORDER BY embedding ").format(
161+
table_name=sql.Identifier(self.table_name)
162+
),
163+
sql.SQL(search_params["metric_fun_op"]),
164+
sql.SQL(" %s::vector LIMIT %s::int"),
165+
]
166+
)
167+
168+
log.debug(f"Unfiltered search query={self._unfiltered_search.as_string(self.conn)}")
169+
log.debug(f"Filtered search query={self._filtered_search.as_string(self.conn)}")
125170

126171
try:
127172
yield
@@ -234,7 +279,7 @@ def _create_index(self):
234279
options.append(
235280
sql.SQL("{option_name} = {val}").format(
236281
option_name=sql.Identifier(option_name),
237-
val=sql.Identifier(str(option_val)),
282+
val=sql.Literal(option_val),
238283
),
239284
)
240285

@@ -314,16 +359,39 @@ def search_embedding(
314359
assert self.conn is not None, "Connection is not initialized"
315360
assert self.cursor is not None, "Cursor is not initialized"
316361

362+
search_params = self.case_config.search_param()
363+
is_reranking = search_params.get("reranking", False)
364+
317365
q = np.asarray(query)
318366
if filters:
319367
gt = filters.get("id")
368+
if is_reranking:
369+
result = self.cursor.execute(
370+
self._filtered_search,
371+
(gt, q, q, k),
372+
prepare=True,
373+
binary=True,
374+
)
375+
else:
376+
result = self.cursor.execute(
377+
self._filtered_search,
378+
(gt, q, k),
379+
prepare=True,
380+
binary=True,
381+
)
382+
elif is_reranking:
320383
result = self.cursor.execute(
321-
self._filtered_search,
322-
(gt, q, k),
384+
self._unfiltered_search,
385+
(q, q, k),
323386
prepare=True,
324387
binary=True,
325388
)
326389
else:
327-
result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
390+
result = self.cursor.execute(
391+
self._unfiltered_search,
392+
(q, k),
393+
prepare=True,
394+
binary=True,
395+
)
328396

329397
return [int(i[0]) for i in result.fetchall()]

vectordb_bench/frontend/config/dbCaseConfigs.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,58 @@ class CaseConfigInput(BaseModel):
423423
},
424424
)
425425

426-
CaseConfigParamInput_max_neighbors = CaseConfigInput(
426+
CaseConfigParamInput_reranking_PgDiskANN = CaseConfigInput(
427+
label=CaseConfigParamType.reranking,
428+
inputType=InputType.Bool,
429+
displayLabel="Enable Reranking",
430+
inputHelp="Enable if you want to use reranking while performing \
431+
similarity search with PQ",
432+
inputConfig={
433+
"value": False,
434+
},
435+
)
436+
437+
CaseConfigParamInput_quantized_fetch_limit_PgDiskANN = CaseConfigInput(
438+
label=CaseConfigParamType.quantized_fetch_limit,
439+
displayLabel="Quantized Fetch Limit",
440+
inputHelp="Limit top-k vectors using the quantized vector comparison",
441+
inputType=InputType.Number,
442+
inputConfig={
443+
"min": 20,
444+
"max": 1000,
445+
"value": 200,
446+
},
447+
isDisplayed=lambda config: config.get(CaseConfigParamType.reranking, False),
448+
)
449+
450+
CaseConfigParamInput_pq_param_num_chunks_PgDiskANN = CaseConfigInput(
451+
label=CaseConfigParamType.pq_param_num_chunks,
452+
displayLabel="pq_param_num_chunks",
453+
inputHelp="Number of chunks for product quantization (Defaults to 0). 0 means it is determined automatically, based on embedding dimensions.",
454+
inputType=InputType.Number,
455+
inputConfig={
456+
"min": 0,
457+
"max": 1028,
458+
"value": 0,
459+
},
460+
isDisplayed=lambda config: config.get(CaseConfigParamType.reranking, False),
461+
)
462+
463+
464+
CaseConfigParamInput_reranking_metric_PgDiskANN = CaseConfigInput(
465+
label=CaseConfigParamType.reranking_metric,
466+
displayLabel="Reranking Metric",
467+
inputType=InputType.Option,
468+
inputConfig={
469+
"options": [metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD", "DP"]],
470+
},
471+
isDisplayed=lambda config: config.get(CaseConfigParamType.reranking, False),
472+
)
473+
474+
475+
CaseConfigParamInput_max_neighbors_PgDiskANN = CaseConfigInput(
427476
label=CaseConfigParamType.max_neighbors,
477+
displayLabel="max_neighbors",
428478
inputType=InputType.Number,
429479
inputConfig={
430480
"min": 10,
@@ -456,6 +506,29 @@ class CaseConfigInput(BaseModel):
456506
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.DISKANN.value,
457507
)
458508

509+
CaseConfigParamInput_maintenance_work_mem_PgDiskANN = CaseConfigInput(
510+
label=CaseConfigParamType.maintenance_work_mem,
511+
inputHelp="Memory to use during index builds. Not to exceed the available free memory."
512+
"Specify in gigabytes. e.g. 8GB",
513+
inputType=InputType.Text,
514+
inputConfig={
515+
"value": "8GB",
516+
},
517+
)
518+
519+
CaseConfigParamInput_max_parallel_workers_PgDiskANN = CaseConfigInput(
520+
label=CaseConfigParamType.max_parallel_workers,
521+
displayLabel="Max parallel workers",
522+
inputHelp="Recommended value: (cpu cores - 1). This will set the parameters: max_parallel_maintenance_workers,"
523+
" max_parallel_workers & table(parallel_workers)",
524+
inputType=InputType.Number,
525+
inputConfig={
526+
"min": 0,
527+
"max": 1024,
528+
"value": 16,
529+
},
530+
)
531+
459532
CaseConfigParamInput_num_neighbors = CaseConfigInput(
460533
label=CaseConfigParamType.num_neighbors,
461534
inputType=InputType.Number,
@@ -1796,15 +1869,21 @@ class CaseConfigInput(BaseModel):
17961869

17971870
PgDiskANNLoadConfig = [
17981871
CaseConfigParamInput_IndexType_PgDiskANN,
1799-
CaseConfigParamInput_max_neighbors,
1872+
CaseConfigParamInput_max_neighbors_PgDiskANN,
18001873
CaseConfigParamInput_l_value_ib,
18011874
]
18021875

18031876
PgDiskANNPerformanceConfig = [
18041877
CaseConfigParamInput_IndexType_PgDiskANN,
1805-
CaseConfigParamInput_max_neighbors,
1878+
CaseConfigParamInput_reranking_PgDiskANN,
1879+
CaseConfigParamInput_max_neighbors_PgDiskANN,
18061880
CaseConfigParamInput_l_value_ib,
18071881
CaseConfigParamInput_l_value_is,
1882+
CaseConfigParamInput_maintenance_work_mem_PgDiskANN,
1883+
CaseConfigParamInput_max_parallel_workers_PgDiskANN,
1884+
CaseConfigParamInput_pq_param_num_chunks_PgDiskANN,
1885+
CaseConfigParamInput_quantized_fetch_limit_PgDiskANN,
1886+
CaseConfigParamInput_reranking_metric_PgDiskANN,
18081887
]
18091888

18101889

0 commit comments

Comments
 (0)