204 changes: 106 additions & 98 deletions vectordb_bench/backend/clients/pgdiskann/pgdiskann.py
@@ -10,6 +10,8 @@
from pgvector.psycopg import register_vector
from psycopg import Connection, Cursor, sql

from vectordb_bench.backend.filter import Filter, FilterOp

from ..api import VectorDB
from .config import PgDiskANNConfigDict, PgDiskANNIndexConfig

@@ -19,11 +21,16 @@
class PgDiskANN(VectorDB):
"""Use psycopg instructions"""

supported_filter_types: list[FilterOp] = [
FilterOp.NonFilter,
FilterOp.NumGE,
FilterOp.StrEqual,
]
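# NumGE filters compare against the BIGINT primary key (id >= value); StrEqual
# filters match the scalar "label" column, which only exists when the client is
# constructed with with_scalar_labels=True (see prepare_filter and _create_table).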

conn: psycopg.Connection[Any] | None = None
cursor: psycopg.Cursor[Any] | None = None

_search: sql.Composed

def __init__(
self,
@@ -32,13 +39,17 @@ def __init__(
db_case_config: PgDiskANNIndexConfig,
collection_name: str = "pg_diskann_collection",
drop_old: bool = False,
with_scalar_labels: bool = False,
**kwargs,
):
self.name = "PgDiskANN"
self.db_config = db_config
self.case_config = db_case_config
self.table_name = collection_name
self.dim = dim
self.with_scalar_labels = with_scalar_labels
self._scalar_label_field = "label"
self.where_clause = ""

self._index_name = "pgdiskann_index"
self._primary_field = "id"
@@ -86,83 +97,58 @@ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:

return conn, cursor


def _generate_search_query(self) -> sql.Composed:
"""Generate search query with where_clause placeholder"""
search_params = self.case_config.search_param()

if search_params.get("reranking"):
# Reranking-enabled queries
search_query = sql.SQL("""
SELECT i.id
FROM (
SELECT id, embedding
FROM public.{table_name}
{where_clause}
ORDER BY embedding {metric_fun_op} %s::vector
LIMIT {quantized_fetch_limit}::int
) i
ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
LIMIT %s::int
""").format(
table_name=sql.Identifier(self.table_name),
where_clause=sql.SQL(self.where_clause),
metric_fun_op=sql.SQL(search_params["metric_fun_op"]),
reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]),
quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]),
)
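# Two-stage reranking query: the inner SELECT fetches quantized_fetch_limit
# candidates ordered by metric_fun_op, and the outer SELECT re-orders them with
# reranking_metric_fun_op before the final LIMIT, so the statement binds the
# query vector twice.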

else:
search_query = sql.Composed(
[
sql.SQL("SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding ").format(
table_name=sql.Identifier(self.table_name),
where_clause=sql.SQL(self.where_clause),
),
sql.SQL(search_params["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
]
)
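# Illustrative rendering of the non-reranking query with a label filter
# (hypothetical operator and label value; the real operator comes from
# search_params["metric_fun_op"]):
#   SELECT id FROM public."pg_diskann_collection" WHERE label = 'red'
#   ORDER BY embedding <=> %s::vector LIMIT %s::int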

return search_query

@contextmanager
def init(self) -> Generator[None, None, None]:
self.conn, self.cursor = self._create_connection(**self.db_config)

session_options: dict[str, Any] = self.case_config.session_param()

if len(session_options) > 0:
for setting_name, setting_val in session_options.items():
command = sql.SQL("SET {setting_name} = {setting_val};").format(
setting_name=sql.Identifier(setting_name),
setting_val=sql.Literal(setting_val),
)
log.debug(command.as_string(self.cursor))
self.cursor.execute(command)
self.conn.commit()

try:
yield
@@ -281,12 +267,10 @@ def _create_index(self):

with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())

index_create_sql = sql.SQL("""
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} (embedding {embedding_metric})
""").format(
index_name=sql.Identifier(self._index_name),
table_name=sql.Identifier(self.table_name),
index_type=sql.Identifier(index_param["index_type"].lower()),
@@ -304,11 +288,36 @@ def _create_table(self, dim: int):
try:
log.info(f"{self.name} client create table : {self.table_name}")

if self.with_scalar_labels:
self.cursor.execute(
sql.SQL("""
CREATE TABLE IF NOT EXISTS public.{table_name}
({primary_field} BIGINT PRIMARY KEY, embedding vector({dim}), {label_field} VARCHAR(64));
""").format(
table_name=sql.Identifier(self.table_name),
dim=dim,
primary_field=sql.Identifier(self._primary_field),
label_field=sql.Identifier(self._scalar_label_field),
),
)
else:
self.cursor.execute(
sql.SQL("""
CREATE TABLE IF NOT EXISTS public.{table_name}
({primary_field} BIGINT PRIMARY KEY, embedding vector({dim}));
""").format(
table_name=sql.Identifier(self.table_name),
dim=dim,
primary_field=sql.Identifier(self._primary_field),
),
)

self.cursor.execute(
sql.SQL("ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;").format(
table_name=sql.Identifier(self.table_name)
),
)
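# SET STORAGE PLAIN keeps the embedding column inline in the heap tuple rather
# than letting TOAST compress it or move it out of line.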

self.conn.commit()
except Exception as e:
log.warning(f"Failed to create pgdiskann table: {self.table_name} error: {e}")
@@ -318,11 +327,15 @@ def insert_embeddings(
self,
embeddings: list[list[float]],
metadata: list[int],
labels_data: list[str] | None = None,
**kwargs: Any,
) -> tuple[int, Exception | None]:
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

if self.with_scalar_labels:
assert labels_data is not None, "labels_data should be provided if with_scalar_labels is set to True"

try:
metadata_arr = np.array(metadata)
embeddings_arr = np.array(embeddings)
@@ -332,9 +345,14 @@ def insert_embeddings(
table_name=sql.Identifier(self.table_name),
),
) as copy:
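# set_types() declares the binary encoding used for each COPY column; the
# "vector" type is assumed to be registered on this connection via
# register_vector() in _create_connection.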
if self.with_scalar_labels:
copy.set_types(["bigint", "vector", "varchar"])
for i, row in enumerate(metadata_arr):
copy.write_row((row, embeddings_arr[i], labels_data[i]))
else:
copy.set_types(["bigint", "vector"])
for i, row in enumerate(metadata_arr):
copy.write_row((row, embeddings_arr[i]))
self.conn.commit()

if kwargs.get("last_batch"):
@@ -345,49 +363,39 @@
log.warning(f"Failed to insert data into table ({self.table_name}), error: {e}")
return 0, e

def prepare_filter(self, filters: Filter):
"""Prepare filter - builds where_clause"""
if filters.type == FilterOp.NonFilter:
self.where_clause = ""
elif filters.type == FilterOp.NumGE:
self.where_clause = f"WHERE {self._primary_field} >= {filters.int_value}"
elif filters.type == FilterOp.StrEqual:
self.where_clause = f"WHERE {self._scalar_label_field} = '{filters.label_value}'"
else:
msg = f"Unsupported filter type for PgDiskANN - {filters}"
raise ValueError(msg)

self._search = self._generate_search_query()
log.debug(f"Search query={self._search.as_string(self.conn)}")
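# Illustrative where_clause values (example numbers/labels only):
#   FilterOp.NumGE    with int_value=10000    -> "WHERE id >= 10000"
#   FilterOp.StrEqual with label_value="red"  -> "WHERE label = 'red'"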

def search_embedding(
self,
query: list[float],
k: int = 100,
filters: dict | None = None,
timeout: int | None = None,
**kwargs: Any,
) -> list[int]:
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

search_params = self.case_config.search_param()
is_reranking = search_params.get("reranking", False)

q = np.asarray(query)

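# Reranking queries bind the query vector twice (candidate scan + rerank) plus k;
# otherwise only the vector and k are bound.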
result = self.cursor.execute(
self._search,
(q, q, k) if search_params.get("reranking", False) else (q, k),
prepare=True,
binary=True,
)

return [int(i[0]) for i in result.fetchall()]