From a7169c14b83957af262bd569f3652dab10d84b41 Mon Sep 17 00:00:00 2001 From: Eesha Faisal Date: Tue, 3 Feb 2026 15:05:15 +0500 Subject: [PATCH 1/7] Add label filtering support to pgdiskann client --- .../backend/clients/pgdiskann/pgdiskann.py | 163 +++++++++++++++--- 1 file changed, 135 insertions(+), 28 deletions(-) diff --git a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py index 5f069ace5..4872bb2f8 100644 --- a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +++ b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py @@ -10,6 +10,7 @@ from pgvector.psycopg import register_vector from psycopg import Connection, Cursor, sql +from vectordb_bench.backend.filter import Filter, FilterOp from ..api import VectorDB from .config import PgDiskANNConfigDict, PgDiskANNIndexConfig @@ -19,8 +20,14 @@ class PgDiskANN(VectorDB): """Use psycopg instructions""" + supported_filter_types: list[FilterOp] = [ + FilterOp.NonFilter, + FilterOp.NumGE, + FilterOp.StrEqual, + ] + conn: psycopg.Connection[Any] | None = None - coursor: psycopg.Cursor[Any] | None = None + cursor: psycopg.Cursor[Any] | None = None _filtered_search: sql.Composed _unfiltered_search: sql.Composed @@ -32,6 +39,7 @@ def __init__( db_case_config: PgDiskANNIndexConfig, collection_name: str = "pg_diskann_collection", drop_old: bool = False, + with_scalar_labels: bool = False, **kwargs, ): self.name = "PgDiskANN" @@ -39,6 +47,10 @@ def __init__( self.case_config = db_case_config self.table_name = collection_name self.dim = dim + self.with_scalar_labels = with_scalar_labels + self._scalar_label_field = "label" + self.filter_op = None + self.filter_value = None self._index_name = "pgdiskann_index" self._primary_field = "id" @@ -304,11 +316,23 @@ def _create_table(self, dim: int): try: log.info(f"{self.name} client create table : {self.table_name}") - self.cursor.execute( - sql.SQL( - "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));", - ).format(table_name=sql.Identifier(self.table_name), dim=dim), - ) + if self.with_scalar_labels: + # Create table WITH label column + self.cursor.execute( + sql.SQL( + """ + CREATE TABLE IF NOT EXISTS public.{table_name} + (id BIGINT PRIMARY KEY, embedding vector({dim}), label VARCHAR(64)); + """, + ).format(table_name=sql.Identifier(self.table_name), dim=dim), + ) + else: + # Create table WITHOUT label column (existing behavior) + self.cursor.execute( + sql.SQL( + "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));", + ).format(table_name=sql.Identifier(self.table_name), dim=dim), + ) self.conn.commit() except Exception as e: log.warning(f"Failed to create pgdiskann table: {self.table_name} error: {e}") @@ -318,11 +342,15 @@ def insert_embeddings( self, embeddings: list[list[float]], metadata: list[int], + labels_data: list[str] | None = None, **kwargs: Any, ) -> tuple[int, Exception | None]: assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + if self.with_scalar_labels: + assert labels_data is not None, "labels_data should be provided if with_scalar_labels is set to True" + try: metadata_arr = np.array(metadata) embeddings_arr = np.array(embeddings) @@ -332,9 +360,14 @@ def insert_embeddings( table_name=sql.Identifier(self.table_name), ), ) as copy: - copy.set_types(["bigint", "vector"]) - for i, row in enumerate(metadata_arr): - copy.write_row((row, embeddings_arr[i])) + if self.with_scalar_labels: + copy.set_types(["bigint", "vector", "varchar"]) + for i, row in enumerate(metadata_arr): + copy.write_row((row, embeddings_arr[i], labels_data[i])) + else: + copy.set_types(["bigint", "vector"]) + for i, row in enumerate(metadata_arr): + copy.write_row((row, embeddings_arr[i])) self.conn.commit() if kwargs.get("last_batch"): @@ -345,49 +378,123 @@ def insert_embeddings( log.warning(f"Failed to insert data into table ({self.table_name}), error: {e}") return 0, e + def prepare_filter(self, filters: Filter): + """Prepare filter for label or integer-based filtering""" + from vectordb_bench.backend.filter import FilterOp + + if filters.type == FilterOp.NonFilter: + self.filter_op = FilterOp.NonFilter + self.filter_value = None + + elif filters.type == FilterOp.NumGE: + # Integer filtering: WHERE id >= X + self.filter_op = FilterOp.NumGE + self.filter_value = filters.int_value + + elif filters.type == FilterOp.StrEqual: + # Label filtering: WHERE label = 'label_1p' + self.filter_op = FilterOp.StrEqual + self.filter_value = filters.label_value + else: + msg = f"Not support Filter for PgDiskANN - {filters}" + raise ValueError(msg) + + def search_embedding( self, query: list[float], k: int = 100, - filters: dict | None = None, + filters: dict | None = None, timeout: int | None = None, ) -> list[int]: assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + from vectordb_bench.backend.filter import FilterOp + search_params = self.case_config.search_param() is_reranking = search_params.get("reranking", False) - + q = np.asarray(query) - if filters: - gt = filters.get("id") + + # Build the appropriate query based on filter_op + if self.filter_op == FilterOp.StrEqual: + # Label filtering: e.g. WHERE label = 'label_1p' + if is_reranking: + query_sql = sql.SQL(""" + SELECT i.id + FROM ( + SELECT id, embedding + FROM public.{table_name} + WHERE {label_field} = %s + ORDER BY embedding {metric_fun_op} %s::vector + LIMIT {quantized_fetch_limit}::int + ) i + ORDER BY i.embedding {reranking_metric_fun_op} %s::vector + LIMIT %s::int + """).format( + table_name=sql.Identifier(self.table_name), + label_field=sql.Identifier(self._scalar_label_field), + metric_fun_op=sql.SQL(search_params["metric_fun_op"]), + reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]), + quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]), + ) + result = self.cursor.execute( + query_sql, + (self.filter_value, q, q, k), + prepare=True, + binary=True, + ) + else: + query_sql = sql.Composed([ + sql.SQL( + "SELECT id FROM public.{table_name} WHERE {label_field} = %s ORDER BY embedding ", + ).format( + table_name=sql.Identifier(self.table_name), + label_field=sql.Identifier(self._scalar_label_field), + ), + sql.SQL(search_params["metric_fun_op"]), + sql.SQL(" %s::vector LIMIT %s::int"), + ]) + result = self.cursor.execute( + query_sql, + (self.filter_value, q, k), + prepare=True, + binary=True, + ) + + elif self.filter_op == FilterOp.NumGE: + # Integer filtering: WHERE id >= X (existing behavior) if is_reranking: result = self.cursor.execute( self._filtered_search, - (gt, q, q, k), + (self.filter_value, q, q, k), prepare=True, binary=True, ) else: result = self.cursor.execute( self._filtered_search, - (gt, q, k), + (self.filter_value, q, k), prepare=True, binary=True, ) - elif is_reranking: - result = self.cursor.execute( - self._unfiltered_search, - (q, q, k), - prepare=True, - binary=True, - ) + else: - result = self.cursor.execute( - self._unfiltered_search, - (q, k), - prepare=True, - binary=True, - ) + # No filtering (existing behavior) + if is_reranking: + result = self.cursor.execute( + self._unfiltered_search, + (q, q, k), + prepare=True, + binary=True, + ) + else: + result = self.cursor.execute( + self._unfiltered_search, + (q, k), + prepare=True, + binary=True, + ) return [int(i[0]) for i in result.fetchall()] From 25e883886967122a1ed12eabe57f70083907a4c4 Mon Sep 17 00:00:00 2001 From: Eesha Faisal Date: Fri, 6 Feb 2026 15:58:57 +0500 Subject: [PATCH 2/7] Refactor pgdiskann filtering logic --- .../backend/clients/pgdiskann/pgdiskann.py | 280 +++++++----------- 1 file changed, 113 insertions(+), 167 deletions(-) diff --git a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py index 4872bb2f8..23bde85b4 100644 --- a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +++ b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py @@ -21,16 +21,15 @@ class PgDiskANN(VectorDB): """Use psycopg instructions""" supported_filter_types: list[FilterOp] = [ - FilterOp.NonFilter, - FilterOp.NumGE, - FilterOp.StrEqual, + FilterOp.NonFilter, + FilterOp.NumGE, + FilterOp.StrEqual, ] conn: psycopg.Connection[Any] | None = None cursor: psycopg.Cursor[Any] | None = None - _filtered_search: sql.Composed - _unfiltered_search: sql.Composed + _search: sql.Composed def __init__( self, @@ -49,8 +48,7 @@ def __init__( self.dim = dim self.with_scalar_labels = with_scalar_labels self._scalar_label_field = "label" - self.filter_op = None - self.filter_value = None + self.where_clause = "" self._index_name = "pgdiskann_index" self._primary_field = "id" @@ -84,6 +82,35 @@ def __init__( self.cursor = None self.conn = None + def get_size_info(self): + try: + assert self.conn is not None, "Connection is not initialized" + assert self.cursor is not None, "Cursor is not initialized" + log.info(f"{self.name} client get size info.") + + size_sql = sql.SQL( + "SELECT pg_size_pretty(pg_table_size('{table_name}')) as table_size, pg_size_pretty(pg_table_size('{index_name}')) as index_size;" + ).format( + table_name=sql.Identifier(self.table_name), + index_name=sql.Identifier(self._index_name), + ) + log.debug(size_sql.as_string(self.cursor)) + self.cursor.execute(size_sql) + self.conn.commit() + result = self.cursor.fetchone() + + if result: + table_size = result[0] + index_size = result[1] + log.info(f"Table Size: {table_size}, Index Size: {index_size}") + return (table_size, index_size) + else: + log.error("No results returned from the query.") + return (0, 0) + except Exception as e: + log.warning("Failed to fetch table and index information") + return (0, 0) + @staticmethod def _create_connection(**kwargs) -> tuple[Connection, Cursor]: conn = psycopg.connect(**kwargs) @@ -98,83 +125,62 @@ def _create_connection(**kwargs) -> tuple[Connection, Cursor]: return conn, cursor - @contextmanager - def init(self) -> Generator[None, None, None]: - self.conn, self.cursor = self._create_connection(**self.db_config) - - session_options: dict[str, Any] = self.case_config.session_param() - - if len(session_options) > 0: - for setting_name, setting_val in session_options.items(): - command = sql.SQL("SET {setting_name} = {setting_val};").format( - setting_name=sql.Identifier(setting_name), setting_val=sql.Literal(setting_val) - ) - log.debug(command.as_string(self.cursor)) - self.cursor.execute(command) - self.conn.commit() - + def _generate_search_query(self) -> sql.Composed: + """Generate search query with where_clause placeholder""" search_params = self.case_config.search_param() if search_params.get("reranking"): - # Reranking-enabled queries - self._filtered_search = sql.SQL(""" - SELECT i.id - FROM ( - SELECT id, embedding - FROM public.{table_name} - WHERE id >= %s - ORDER BY embedding {metric_fun_op} %s::vector - LIMIT {quantized_fetch_limit}::int - ) i - ORDER BY i.embedding {reranking_metric_fun_op} %s::vector - LIMIT %s::int - """).format( - table_name=sql.Identifier(self.table_name), - metric_fun_op=sql.SQL(search_params["metric_fun_op"]), - reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]), - quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]), - ) - - self._unfiltered_search = sql.SQL(""" + search_query = sql.SQL( + """ SELECT i.id FROM ( SELECT id, embedding FROM public.{table_name} + {where_clause} ORDER BY embedding {metric_fun_op} %s::vector LIMIT {quantized_fetch_limit}::int ) i ORDER BY i.embedding {reranking_metric_fun_op} %s::vector LIMIT %s::int - """).format( + """ + ).format( table_name=sql.Identifier(self.table_name), + where_clause=sql.SQL(self.where_clause), metric_fun_op=sql.SQL(search_params["metric_fun_op"]), reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]), quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]), ) - else: - self._filtered_search = sql.Composed( + search_query = sql.Composed( [ sql.SQL( - "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ", - ).format(table_name=sql.Identifier(self.table_name)), - sql.SQL(search_params["metric_fun_op"]), - sql.SQL(" %s::vector LIMIT %s::int"), - ] - ) - - self._unfiltered_search = sql.Composed( - [ - sql.SQL("SELECT id FROM public.{table_name} ORDER BY embedding ").format( - table_name=sql.Identifier(self.table_name) + "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding " + ).format( + table_name=sql.Identifier(self.table_name), + where_clause=sql.SQL(self.where_clause), ), sql.SQL(search_params["metric_fun_op"]), sql.SQL(" %s::vector LIMIT %s::int"), ] ) - log.debug(f"Unfiltered search query={self._unfiltered_search.as_string(self.conn)}") - log.debug(f"Filtered search query={self._filtered_search.as_string(self.conn)}") + return search_query + + @contextmanager + def init(self) -> Generator[None, None, None]: + self.conn, self.cursor = self._create_connection(**self.db_config) + + session_options: dict[str, Any] = self.case_config.session_param() + + if len(session_options) > 0: + for setting_name, setting_val in session_options.items(): + command = sql.SQL("SET {setting_name} = {setting_val};").format( + setting_name=sql.Identifier(setting_name), + setting_val=sql.Literal(setting_val), + ) + log.debug(command.as_string(self.cursor)) + self.cursor.execute(command) + self.conn.commit() try: yield @@ -291,13 +297,17 @@ def _create_index(self): ), ) - with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(()) + with_clause = ( + sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) + if any(options) + else sql.Composed(()) + ) index_create_sql = sql.SQL( """ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} USING {index_type} (embedding {embedding_metric}) - """, + """ ).format( index_name=sql.Identifier(self._index_name), table_name=sql.Identifier(self.table_name), @@ -317,22 +327,39 @@ def _create_table(self, dim: int): log.info(f"{self.name} client create table : {self.table_name}") if self.with_scalar_labels: - # Create table WITH label column self.cursor.execute( sql.SQL( """ CREATE TABLE IF NOT EXISTS public.{table_name} - (id BIGINT PRIMARY KEY, embedding vector({dim}), label VARCHAR(64)); - """, - ).format(table_name=sql.Identifier(self.table_name), dim=dim), + ({primary_field} BIGINT PRIMARY KEY, embedding vector({dim}), {label_field} VARCHAR(64)); + """ + ).format( + table_name=sql.Identifier(self.table_name), + dim=dim, + primary_field=sql.Identifier(self._primary_field), + label_field=sql.Identifier(self._scalar_label_field), + ), ) else: - # Create table WITHOUT label column (existing behavior) self.cursor.execute( sql.SQL( - "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));", - ).format(table_name=sql.Identifier(self.table_name), dim=dim), + """ + CREATE TABLE IF NOT EXISTS public.{table_name} + ({primary_field} BIGINT PRIMARY KEY, embedding vector({dim})); + """ + ).format( + table_name=sql.Identifier(self.table_name), + dim=dim, + primary_field=sql.Identifier(self._primary_field), + ), ) + + self.cursor.execute( + sql.SQL( + "ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;" + ).format(table_name=sql.Identifier(self.table_name)), + ) + self.conn.commit() except Exception as e: log.warning(f"Failed to create pgdiskann table: {self.table_name} error: {e}") @@ -342,14 +369,16 @@ def insert_embeddings( self, embeddings: list[list[float]], metadata: list[int], - labels_data: list[str] | None = None, + labels_data: list[str] | None = None, **kwargs: Any, ) -> tuple[int, Exception | None]: assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" if self.with_scalar_labels: - assert labels_data is not None, "labels_data should be provided if with_scalar_labels is set to True" + assert ( + labels_data is not None + ), "labels_data should be provided if with_scalar_labels is set to True" try: metadata_arr = np.array(metadata) @@ -379,122 +408,39 @@ def insert_embeddings( return 0, e def prepare_filter(self, filters: Filter): - """Prepare filter for label or integer-based filtering""" - from vectordb_bench.backend.filter import FilterOp - + """Prepare filter - builds where_clause""" if filters.type == FilterOp.NonFilter: - self.filter_op = FilterOp.NonFilter - self.filter_value = None - + self.where_clause = "" elif filters.type == FilterOp.NumGE: - # Integer filtering: WHERE id >= X - self.filter_op = FilterOp.NumGE - self.filter_value = filters.int_value - + self.where_clause = f"WHERE {self._primary_field} >= {filters.int_value}" elif filters.type == FilterOp.StrEqual: - # Label filtering: WHERE label = 'label_1p' - self.filter_op = FilterOp.StrEqual - self.filter_value = filters.label_value + self.where_clause = f"WHERE {self._scalar_label_field} = '{filters.label_value}'" else: msg = f"Not support Filter for PgDiskANN - {filters}" raise ValueError(msg) + self._search = self._generate_search_query() + log.debug(f"Search query={self._search.as_string(self.conn)}") def search_embedding( self, query: list[float], k: int = 100, - filters: dict | None = None, timeout: int | None = None, + **kwargs: Any, ) -> list[int]: assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" - from vectordb_bench.backend.filter import FilterOp - search_params = self.case_config.search_param() - is_reranking = search_params.get("reranking", False) - q = np.asarray(query) - # Build the appropriate query based on filter_op - if self.filter_op == FilterOp.StrEqual: - # Label filtering: e.g. WHERE label = 'label_1p' - if is_reranking: - query_sql = sql.SQL(""" - SELECT i.id - FROM ( - SELECT id, embedding - FROM public.{table_name} - WHERE {label_field} = %s - ORDER BY embedding {metric_fun_op} %s::vector - LIMIT {quantized_fetch_limit}::int - ) i - ORDER BY i.embedding {reranking_metric_fun_op} %s::vector - LIMIT %s::int - """).format( - table_name=sql.Identifier(self.table_name), - label_field=sql.Identifier(self._scalar_label_field), - metric_fun_op=sql.SQL(search_params["metric_fun_op"]), - reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]), - quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]), - ) - result = self.cursor.execute( - query_sql, - (self.filter_value, q, q, k), - prepare=True, - binary=True, - ) - else: - query_sql = sql.Composed([ - sql.SQL( - "SELECT id FROM public.{table_name} WHERE {label_field} = %s ORDER BY embedding ", - ).format( - table_name=sql.Identifier(self.table_name), - label_field=sql.Identifier(self._scalar_label_field), - ), - sql.SQL(search_params["metric_fun_op"]), - sql.SQL(" %s::vector LIMIT %s::int"), - ]) - result = self.cursor.execute( - query_sql, - (self.filter_value, q, k), - prepare=True, - binary=True, - ) - - elif self.filter_op == FilterOp.NumGE: - # Integer filtering: WHERE id >= X (existing behavior) - if is_reranking: - result = self.cursor.execute( - self._filtered_search, - (self.filter_value, q, q, k), - prepare=True, - binary=True, - ) - else: - result = self.cursor.execute( - self._filtered_search, - (self.filter_value, q, k), - prepare=True, - binary=True, - ) - - else: - # No filtering (existing behavior) - if is_reranking: - result = self.cursor.execute( - self._unfiltered_search, - (q, q, k), - prepare=True, - binary=True, - ) - else: - result = self.cursor.execute( - self._unfiltered_search, - (q, k), - prepare=True, - binary=True, - ) - + result = self.cursor.execute( + self._search, + (q, q, k) if search_params.get("reranking", False) else (q, k), + prepare=True, + binary=True, + ) + return [int(i[0]) for i in result.fetchall()] + From 2e4459c7f8cadd3b175ebcd1fb2ce40539bd127d Mon Sep 17 00:00:00 2001 From: Eesha Faisal Date: Fri, 13 Feb 2026 12:44:35 +0500 Subject: [PATCH 3/7] Refactor: remove unrelated function --- .../backend/clients/pgdiskann/pgdiskann.py | 29 ------------------- 1 file changed, 29 deletions(-) diff --git a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py index 23bde85b4..90b87ae01 100644 --- a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +++ b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py @@ -82,35 +82,6 @@ def __init__( self.cursor = None self.conn = None - def get_size_info(self): - try: - assert self.conn is not None, "Connection is not initialized" - assert self.cursor is not None, "Cursor is not initialized" - log.info(f"{self.name} client get size info.") - - size_sql = sql.SQL( - "SELECT pg_size_pretty(pg_table_size('{table_name}')) as table_size, pg_size_pretty(pg_table_size('{index_name}')) as index_size;" - ).format( - table_name=sql.Identifier(self.table_name), - index_name=sql.Identifier(self._index_name), - ) - log.debug(size_sql.as_string(self.cursor)) - self.cursor.execute(size_sql) - self.conn.commit() - result = self.cursor.fetchone() - - if result: - table_size = result[0] - index_size = result[1] - log.info(f"Table Size: {table_size}, Index Size: {index_size}") - return (table_size, index_size) - else: - log.error("No results returned from the query.") - return (0, 0) - except Exception as e: - log.warning("Failed to fetch table and index information") - return (0, 0) - @staticmethod def _create_connection(**kwargs) -> tuple[Connection, Cursor]: conn = psycopg.connect(**kwargs) From d0a7129a99cb1d4b3936b0506a785cdaf3c19fbe Mon Sep 17 00:00:00 2001 From: Eesha Faisal Date: Fri, 17 Apr 2026 12:30:57 +0500 Subject: [PATCH 4/7] style: apply black formatting to pgdiskann.py --- .../backend/clients/pgdiskann/pgdiskann.py | 55 +++++++------------ 1 file changed, 19 insertions(+), 36 deletions(-) diff --git a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py index 90b87ae01..7d0962653 100644 --- a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +++ b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py @@ -29,7 +29,7 @@ class PgDiskANN(VectorDB): conn: psycopg.Connection[Any] | None = None cursor: psycopg.Cursor[Any] | None = None - _search: sql.Composed + _search: sql.Composed def __init__( self, @@ -48,7 +48,7 @@ def __init__( self.dim = dim self.with_scalar_labels = with_scalar_labels self._scalar_label_field = "label" - self.where_clause = "" + self.where_clause = "" self._index_name = "pgdiskann_index" self._primary_field = "id" @@ -101,8 +101,7 @@ def _generate_search_query(self) -> sql.Composed: search_params = self.case_config.search_param() if search_params.get("reranking"): - search_query = sql.SQL( - """ + search_query = sql.SQL(""" SELECT i.id FROM ( SELECT id, embedding @@ -113,8 +112,7 @@ def _generate_search_query(self) -> sql.Composed: ) i ORDER BY i.embedding {reranking_metric_fun_op} %s::vector LIMIT %s::int - """ - ).format( + """).format( table_name=sql.Identifier(self.table_name), where_clause=sql.SQL(self.where_clause), metric_fun_op=sql.SQL(search_params["metric_fun_op"]), @@ -124,9 +122,7 @@ def _generate_search_query(self) -> sql.Composed: else: search_query = sql.Composed( [ - sql.SQL( - "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding " - ).format( + sql.SQL("SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding ").format( table_name=sql.Identifier(self.table_name), where_clause=sql.SQL(self.where_clause), ), @@ -268,18 +264,12 @@ def _create_index(self): ), ) - with_clause = ( - sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) - if any(options) - else sql.Composed(()) - ) + with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(()) - index_create_sql = sql.SQL( - """ + index_create_sql = sql.SQL(""" CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} USING {index_type} (embedding {embedding_metric}) - """ - ).format( + """).format( index_name=sql.Identifier(self._index_name), table_name=sql.Identifier(self.table_name), index_type=sql.Identifier(index_param["index_type"].lower()), @@ -299,12 +289,10 @@ def _create_table(self, dim: int): if self.with_scalar_labels: self.cursor.execute( - sql.SQL( - """ + sql.SQL(""" CREATE TABLE IF NOT EXISTS public.{table_name} ({primary_field} BIGINT PRIMARY KEY, embedding vector({dim}), {label_field} VARCHAR(64)); - """ - ).format( + """).format( table_name=sql.Identifier(self.table_name), dim=dim, primary_field=sql.Identifier(self._primary_field), @@ -313,12 +301,10 @@ def _create_table(self, dim: int): ) else: self.cursor.execute( - sql.SQL( - """ + sql.SQL(""" CREATE TABLE IF NOT EXISTS public.{table_name} ({primary_field} BIGINT PRIMARY KEY, embedding vector({dim})); - """ - ).format( + """).format( table_name=sql.Identifier(self.table_name), dim=dim, primary_field=sql.Identifier(self._primary_field), @@ -326,11 +312,11 @@ def _create_table(self, dim: int): ) self.cursor.execute( - sql.SQL( - "ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;" - ).format(table_name=sql.Identifier(self.table_name)), + sql.SQL("ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;").format( + table_name=sql.Identifier(self.table_name) + ), ) - + self.conn.commit() except Exception as e: log.warning(f"Failed to create pgdiskann table: {self.table_name} error: {e}") @@ -347,9 +333,7 @@ def insert_embeddings( assert self.cursor is not None, "Cursor is not initialized" if self.with_scalar_labels: - assert ( - labels_data is not None - ), "labels_data should be provided if with_scalar_labels is set to True" + assert labels_data is not None, "labels_data should be provided if with_scalar_labels is set to True" try: metadata_arr = np.array(metadata) @@ -405,13 +389,12 @@ def search_embedding( search_params = self.case_config.search_param() q = np.asarray(query) - + result = self.cursor.execute( self._search, (q, q, k) if search_params.get("reranking", False) else (q, k), prepare=True, binary=True, ) - - return [int(i[0]) for i in result.fetchall()] + return [int(i[0]) for i in result.fetchall()] From 56f13dcb4d02512e5d960fa9d16b478005043151 Mon Sep 17 00:00:00 2001 From: Eesha Faisal Date: Fri, 17 Apr 2026 12:38:27 +0500 Subject: [PATCH 5/7] fix: remove trailing whitespace and fix import sorting --- vectordb_bench/backend/clients/pgdiskann/pgdiskann.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py index 7d0962653..46e8fabd4 100644 --- a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +++ b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py @@ -11,6 +11,7 @@ from psycopg import Connection, Cursor, sql from vectordb_bench.backend.filter import Filter, FilterOp + from ..api import VectorDB from .config import PgDiskANNConfigDict, PgDiskANNIndexConfig @@ -290,7 +291,7 @@ def _create_table(self, dim: int): if self.with_scalar_labels: self.cursor.execute( sql.SQL(""" - CREATE TABLE IF NOT EXISTS public.{table_name} + CREATE TABLE IF NOT EXISTS public.{table_name} ({primary_field} BIGINT PRIMARY KEY, embedding vector({dim}), {label_field} VARCHAR(64)); """).format( table_name=sql.Identifier(self.table_name), @@ -302,7 +303,7 @@ def _create_table(self, dim: int): else: self.cursor.execute( sql.SQL(""" - CREATE TABLE IF NOT EXISTS public.{table_name} + CREATE TABLE IF NOT EXISTS public.{table_name} ({primary_field} BIGINT PRIMARY KEY, embedding vector({dim})); """).format( table_name=sql.Identifier(self.table_name), From d10b296b612f6568f5c2abf043b0003d2f4ca8b4 Mon Sep 17 00:00:00 2001 From: Eesha Faisal Date: Mon, 20 Apr 2026 17:34:13 +0500 Subject: [PATCH 6/7] docs: add comments for label naming and vector storage optimization --- vectordb_bench/backend/clients/pgdiskann/pgdiskann.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py index 46e8fabd4..f9b9f0362 100644 --- a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +++ b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py @@ -48,6 +48,8 @@ def __init__( self.table_name = collection_name self.dim = dim self.with_scalar_labels = with_scalar_labels + # Table column for scalar labels (standard naming convention across all database clients). + # Note: Dataset uses "labels" (plural), table uses "label" (singular). self._scalar_label_field = "label" self.where_clause = "" @@ -312,6 +314,7 @@ def _create_table(self, dim: int): ), ) + # Disable TOAST compression on the vector column to improve query performance. self.cursor.execute( sql.SQL("ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;").format( table_name=sql.Identifier(self.table_name) From 07d81d7b9c85f9b3762f4a3c5c0f5c7c592efef8 Mon Sep 17 00:00:00 2001 From: Eesha Faisal Date: Tue, 21 Apr 2026 12:00:09 +0500 Subject: [PATCH 7/7] Revert "docs: add comments for label naming and vector storage optimization" This reverts commit d10b296b612f6568f5c2abf043b0003d2f4ca8b4. --- vectordb_bench/backend/clients/pgdiskann/pgdiskann.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py index f9b9f0362..46e8fabd4 100644 --- a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +++ b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py @@ -48,8 +48,6 @@ def __init__( self.table_name = collection_name self.dim = dim self.with_scalar_labels = with_scalar_labels - # Table column for scalar labels (standard naming convention across all database clients). - # Note: Dataset uses "labels" (plural), table uses "label" (singular). self._scalar_label_field = "label" self.where_clause = "" @@ -314,7 +312,6 @@ def _create_table(self, dim: int): ), ) - # Disable TOAST compression on the vector column to improve query performance. self.cursor.execute( sql.SQL("ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;").format( table_name=sql.Identifier(self.table_name)