diff --git a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py
index 5f069ace5..46e8fabd4 100644
--- a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py
+++ b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py
@@ -10,6 +10,8 @@
 from pgvector.psycopg import register_vector
 from psycopg import Connection, Cursor, sql
 
+from vectordb_bench.backend.filter import Filter, FilterOp
+
 from ..api import VectorDB
 from .config import PgDiskANNConfigDict, PgDiskANNIndexConfig
 
@@ -19,11 +21,16 @@ class PgDiskANN(VectorDB):
     """Use psycopg instructions"""
 
+    supported_filter_types: list[FilterOp] = [
+        FilterOp.NonFilter,
+        FilterOp.NumGE,
+        FilterOp.StrEqual,
+    ]
+
     conn: psycopg.Connection[Any] | None = None
-    coursor: psycopg.Cursor[Any] | None = None
+    cursor: psycopg.Cursor[Any] | None = None
 
-    _filtered_search: sql.Composed
-    _unfiltered_search: sql.Composed
+    _search: sql.Composed
 
     def __init__(
         self,
@@ -32,6 +39,7 @@ def __init__(
         db_case_config: PgDiskANNIndexConfig,
         collection_name: str = "pg_diskann_collection",
         drop_old: bool = False,
+        with_scalar_labels: bool = False,
         **kwargs,
     ):
         self.name = "PgDiskANN"
@@ -39,6 +47,9 @@ def __init__(
         self.case_config = db_case_config
         self.table_name = collection_name
         self.dim = dim
+        self.with_scalar_labels = with_scalar_labels
+        self._scalar_label_field = "label"
+        self.where_clause = ""
 
         self._index_name = "pgdiskann_index"
         self._primary_field = "id"
@@ -86,83 +97,58 @@ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
 
         return conn, cursor
 
-    @contextmanager
-    def init(self) -> Generator[None, None, None]:
-        self.conn, self.cursor = self._create_connection(**self.db_config)
-
-        session_options: dict[str, Any] = self.case_config.session_param()
-
-        if len(session_options) > 0:
-            for setting_name, setting_val in session_options.items():
-                command = sql.SQL("SET {setting_name} = {setting_val};").format(
-                    setting_name=sql.Identifier(setting_name), setting_val=sql.Literal(setting_val)
-                )
-                log.debug(command.as_string(self.cursor))
-                self.cursor.execute(command)
-            self.conn.commit()
-
+    def _generate_search_query(self) -> sql.Composed:
+        """Generate search query with where_clause placeholder"""
         search_params = self.case_config.search_param()
 
         if search_params.get("reranking"):
-            # Reranking-enabled queries
-            self._filtered_search = sql.SQL("""
+            search_query = sql.SQL("""
                 SELECT i.id
                 FROM (
                     SELECT id, embedding
                     FROM public.{table_name}
-                    WHERE id >= %s
+                    {where_clause}
                     ORDER BY embedding {metric_fun_op} %s::vector
                     LIMIT {quantized_fetch_limit}::int
                 ) i
                 ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
                 LIMIT %s::int
-                """).format(
+            """).format(
                 table_name=sql.Identifier(self.table_name),
+                where_clause=sql.SQL(self.where_clause),
                 metric_fun_op=sql.SQL(search_params["metric_fun_op"]),
                 reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]),
                 quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]),
             )
-
-            self._unfiltered_search = sql.SQL("""
-                SELECT i.id
-                FROM (
-                    SELECT id, embedding
-                    FROM public.{table_name}
-                    ORDER BY embedding {metric_fun_op} %s::vector
-                    LIMIT {quantized_fetch_limit}::int
-                ) i
-                ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
-                LIMIT %s::int
-                """).format(
-                table_name=sql.Identifier(self.table_name),
-                metric_fun_op=sql.SQL(search_params["metric_fun_op"]),
-                reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]),
-                quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]),
-            )
-
         else:
-            self._filtered_search = sql.Composed(
+            search_query = sql.Composed(
                 [
-                    sql.SQL(
-                        "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ",
-                    ).format(table_name=sql.Identifier(self.table_name)),
-                    sql.SQL(search_params["metric_fun_op"]),
-                    sql.SQL(" %s::vector LIMIT %s::int"),
-                ]
-            )
-
-            self._unfiltered_search = sql.Composed(
-                [
-                    sql.SQL("SELECT id FROM public.{table_name} ORDER BY embedding ").format(
-                        table_name=sql.Identifier(self.table_name)
+                    sql.SQL("SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding ").format(
+                        table_name=sql.Identifier(self.table_name),
+                        where_clause=sql.SQL(self.where_clause),
                     ),
                     sql.SQL(search_params["metric_fun_op"]),
                     sql.SQL(" %s::vector LIMIT %s::int"),
                 ]
             )
 
-        log.debug(f"Unfiltered search query={self._unfiltered_search.as_string(self.conn)}")
-        log.debug(f"Filtered search query={self._filtered_search.as_string(self.conn)}")
+        return search_query
+
+    @contextmanager
+    def init(self) -> Generator[None, None, None]:
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        session_options: dict[str, Any] = self.case_config.session_param()
+
+        if len(session_options) > 0:
+            for setting_name, setting_val in session_options.items():
+                command = sql.SQL("SET {setting_name} = {setting_val};").format(
+                    setting_name=sql.Identifier(setting_name),
+                    setting_val=sql.Literal(setting_val),
+                )
+                log.debug(command.as_string(self.cursor))
+                self.cursor.execute(command)
            self.conn.commit()
 
         try:
             yield
@@ -281,12 +267,10 @@ def _create_index(self):
 
         with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
 
-        index_create_sql = sql.SQL(
-            """
+        index_create_sql = sql.SQL("""
             CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
             USING {index_type} (embedding {embedding_metric})
-            """,
-        ).format(
+        """).format(
             index_name=sql.Identifier(self._index_name),
             table_name=sql.Identifier(self.table_name),
             index_type=sql.Identifier(index_param["index_type"].lower()),
@@ -304,11 +288,36 @@ def _create_table(self, dim: int):
 
         try:
             log.info(f"{self.name} client create table : {self.table_name}")
+            if self.with_scalar_labels:
+                self.cursor.execute(
+                    sql.SQL("""
+                        CREATE TABLE IF NOT EXISTS public.{table_name}
+                        ({primary_field} BIGINT PRIMARY KEY, embedding vector({dim}), {label_field} VARCHAR(64));
+                    """).format(
+                        table_name=sql.Identifier(self.table_name),
+                        dim=dim,
+                        primary_field=sql.Identifier(self._primary_field),
+                        label_field=sql.Identifier(self._scalar_label_field),
+                    ),
+                )
+            else:
+                self.cursor.execute(
+                    sql.SQL("""
+                        CREATE TABLE IF NOT EXISTS public.{table_name}
+                        ({primary_field} BIGINT PRIMARY KEY, embedding vector({dim}));
+                    """).format(
+                        table_name=sql.Identifier(self.table_name),
+                        dim=dim,
+                        primary_field=sql.Identifier(self._primary_field),
+                    ),
+                )
+
             self.cursor.execute(
-                sql.SQL(
-                    "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
-                ).format(table_name=sql.Identifier(self.table_name), dim=dim),
+                sql.SQL("ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;").format(
+                    table_name=sql.Identifier(self.table_name)
+                ),
             )
+
             self.conn.commit()
         except Exception as e:
             log.warning(f"Failed to create pgdiskann table: {self.table_name} error: {e}")
@@ -318,11 +327,15 @@ def insert_embeddings(
         self,
         embeddings: list[list[float]],
         metadata: list[int],
+        labels_data: list[str] | None = None,
         **kwargs: Any,
     ) -> tuple[int, Exception | None]:
         assert self.conn is not None, "Connection is not initialized"
         assert self.cursor is not None, "Cursor is not initialized"
 
+        if self.with_scalar_labels:
+            assert labels_data is not None, "labels_data should be provided if with_scalar_labels is set to True"
+
         try:
             metadata_arr = np.array(metadata)
             embeddings_arr = np.array(embeddings)
@@ -332,9 +345,14 @@ def insert_embeddings(
                     table_name=sql.Identifier(self.table_name),
                 ),
             ) as copy:
-                copy.set_types(["bigint", "vector"])
-                for i, row in enumerate(metadata_arr):
-                    copy.write_row((row, embeddings_arr[i]))
+                if self.with_scalar_labels:
+                    copy.set_types(["bigint", "vector", "varchar"])
+                    for i, row in enumerate(metadata_arr):
+                        copy.write_row((row, embeddings_arr[i], labels_data[i]))
+                else:
+                    copy.set_types(["bigint", "vector"])
+                    for i, row in enumerate(metadata_arr):
+                        copy.write_row((row, embeddings_arr[i]))
 
             self.conn.commit()
             if kwargs.get("last_batch"):
@@ -345,49 +363,39 @@ def insert_embeddings(
             log.warning(f"Failed to insert data into table ({self.table_name}), error: {e}")
             return 0, e
 
+    def prepare_filter(self, filters: Filter):
+        """Prepare filter - builds where_clause"""
+        if filters.type == FilterOp.NonFilter:
+            self.where_clause = ""
+        elif filters.type == FilterOp.NumGE:
+            self.where_clause = f"WHERE {self._primary_field} >= {filters.int_value}"
+        elif filters.type == FilterOp.StrEqual:
+            self.where_clause = f"WHERE {self._scalar_label_field} = '{filters.label_value}'"
+        else:
+            msg = f"Not support Filter for PgDiskANN - {filters}"
+            raise ValueError(msg)
+
+        self._search = self._generate_search_query()
+        log.debug(f"Search query={self._search.as_string(self.conn)}")
+
     def search_embedding(
         self,
         query: list[float],
         k: int = 100,
-        filters: dict | None = None,
         timeout: int | None = None,
+        **kwargs: Any,
     ) -> list[int]:
         assert self.conn is not None, "Connection is not initialized"
         assert self.cursor is not None, "Cursor is not initialized"
 
         search_params = self.case_config.search_param()
-        is_reranking = search_params.get("reranking", False)
-
         q = np.asarray(query)
-        if filters:
-            gt = filters.get("id")
-            if is_reranking:
-                result = self.cursor.execute(
-                    self._filtered_search,
-                    (gt, q, q, k),
-                    prepare=True,
-                    binary=True,
-                )
-            else:
-                result = self.cursor.execute(
-                    self._filtered_search,
-                    (gt, q, k),
-                    prepare=True,
-                    binary=True,
-                )
-        elif is_reranking:
-            result = self.cursor.execute(
-                self._unfiltered_search,
-                (q, q, k),
-                prepare=True,
-                binary=True,
-            )
-        else:
-            result = self.cursor.execute(
-                self._unfiltered_search,
-                (q, k),
-                prepare=True,
-                binary=True,
-            )
+
+        result = self.cursor.execute(
+            self._search,
+            (q, q, k) if search_params.get("reranking", False) else (q, k),
+            prepare=True,
+            binary=True,
+        )
 
         return [int(i[0]) for i in result.fetchall()]