diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index 21903594e..c096a14c8 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -47,7 +47,6 @@ class DB(Enum): Clickhouse = "Clickhouse" Vespa = "Vespa" LanceDB = "LanceDB" - @property def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915 @@ -76,10 +75,10 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915 from .qdrant_cloud.qdrant_cloud import QdrantCloud return QdrantCloud - + if self == DB.QdrantLocal: from .qdrant_local.qdrant_local import QdrantLocal - + return QdrantLocal if self == DB.WeaviateCloud: @@ -207,10 +206,12 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901, PLR0915 from .qdrant_cloud.config import QdrantConfig return QdrantConfig - + if self == DB.QdrantLocal: from .qdrant_local.config import QdrantLocalConfig + return QdrantLocalConfig + if self == DB.WeaviateCloud: from .weaviate_cloud.config import WeaviateConfig @@ -332,10 +333,10 @@ def case_config_cls( # noqa: C901, PLR0911, PLR0912 from .qdrant_cloud.config import QdrantIndexConfig return QdrantIndexConfig - + if self == DB.QdrantLocal: from .qdrant_local.config import QdrantLocalIndexConfig - + return QdrantLocalIndexConfig if self == DB.WeaviateCloud: diff --git a/vectordb_bench/backend/clients/qdrant_local/cli.py b/vectordb_bench/backend/clients/qdrant_local/cli.py index c01f0afb7..7995b99b3 100644 --- a/vectordb_bench/backend/clients/qdrant_local/cli.py +++ b/vectordb_bench/backend/clients/qdrant_local/cli.py @@ -1,4 +1,4 @@ -from typing import Annotated, TypedDict, Unpack +from typing import Annotated, Unpack import click from pydantic import SecretStr @@ -11,7 +11,6 @@ run, ) - DBTYPE = DB.QdrantLocal @@ -22,29 +21,27 @@ class QdrantLocalTypedDict(CommonTypedDict): ] on_disk: Annotated[ bool, - click.option( - "--on-disk", type=bool, default=False, help="Store the vectors and the HNSW index on disk" - ), + click.option("--on-disk", type=bool, default=False, help="Store the vectors and the HNSW index on disk"), ] m: Annotated[ int, - click.option( - "--m", type=int, default=16, help="HNSW index parameter m, set 0 to disable the index" - ), + click.option("--m", type=int, default=16, help="HNSW index parameter m, set 0 to disable the index"), ] ef_construct: Annotated[ int, - click.option( - "--ef-construct", type=int, default=200, help="HNSW index parameter ef_construct" - ), + click.option("--ef-construct", type=int, default=200, help="HNSW index parameter ef_construct"), ] hnsw_ef: Annotated[ int, click.option( - "--hnsw-ef", type=int, default=0, help="HNSW index parameter hnsw_ef, set 0 to use ef_construct for search", + "--hnsw-ef", + type=int, + default=0, + help="HNSW index parameter hnsw_ef, set 0 to use ef_construct for search", ), ] + @cli.command() @click_parameter_decorators_from_typed_dict(QdrantLocalTypedDict) def QdrantLocal(**parameters: Unpack[QdrantLocalTypedDict]): @@ -52,9 +49,7 @@ def QdrantLocal(**parameters: Unpack[QdrantLocalTypedDict]): run( db=DBTYPE, - db_config=QdrantLocalConfig( - url=SecretStr(parameters["url"]) - ), + db_config=QdrantLocalConfig(url=SecretStr(parameters["url"])), db_case_config=QdrantLocalIndexConfig( on_disk=parameters["on_disk"], m=parameters["m"], diff --git a/vectordb_bench/backend/clients/qdrant_local/config.py b/vectordb_bench/backend/clients/qdrant_local/config.py index b2949313f..ebdf99dc4 100644 --- a/vectordb_bench/backend/clients/qdrant_local/config.py +++ b/vectordb_bench/backend/clients/qdrant_local/config.py @@ -1,10 +1,11 @@ from pydantic import BaseModel, SecretStr -from ..api import DBCaseConfig, DBConfig, IndexType, MetricType +from ..api import DBCaseConfig, DBConfig, MetricType + class QdrantLocalConfig(DBConfig): url: SecretStr - + def to_dict(self) -> dict: return { "url": self.url.get_secret_value(), @@ -17,7 +18,7 @@ class QdrantLocalIndexConfig(BaseModel, DBCaseConfig): ef_construct: int hnsw_ef: int | None = 0 on_disk: bool | None = False - + def parse_metric(self) -> str: if self.metric_type == MetricType.L2: return "Euclid" @@ -26,7 +27,7 @@ def parse_metric(self) -> str: return "Dot" return "Cosine" - + def index_param(self) -> dict: return { "distance": self.parse_metric(), @@ -34,13 +35,13 @@ def index_param(self) -> dict: "ef_construct": self.ef_construct, "on_disk": self.on_disk, } - + def search_param(self) -> dict: search_params = { - "exact": False, # Force to use ANNs + "exact": False, # Force to use ANNs } - + if self.hnsw_ef != 0: search_params["hnsw_ef"] = self.hnsw_ef - - return search_params \ No newline at end of file + + return search_params diff --git a/vectordb_bench/backend/clients/qdrant_local/qdrant_local.py b/vectordb_bench/backend/clients/qdrant_local/qdrant_local.py index 723808e8f..1340be614 100644 --- a/vectordb_bench/backend/clients/qdrant_local/qdrant_local.py +++ b/vectordb_bench/backend/clients/qdrant_local/qdrant_local.py @@ -28,22 +28,23 @@ QDRANT_BATCH_SIZE = 100 -def qdrant_collection_exists(client, collection_name: str) -> bool: +def qdrant_collection_exists(client: QdrantClient, collection_name: str) -> bool: collection_exists = True - + try: client.get_collection(collection_name) - except Exception as e: + except Exception: collection_exists = False - + return collection_exists - + + class QdrantLocal(VectorDB): def __init__( self, dim: int, db_config: dict, - db_case_config: dict, + db_case_config: QdrantLocalIndexConfig, collection_name: str = "QdrantLocalCollection", drop_old: bool = False, name: str = "QdrantLocal", @@ -56,26 +57,26 @@ def __init__( self.search_parameter = self.case_config.search_param() self.collection_name = collection_name self.client = None - + self._primary_field = "pk" self._vector_field = "vector" - + client = QdrantClient(**self.db_config) - + # Lets just print the parameters here for double check log.info(f"Case config: {self.case_config.index_param()}") log.info(f"Search parameter: {self.search_parameter}") - + if drop_old and qdrant_collection_exists(client, self.collection_name): log.info(f"{self.name} client drop_old collection: {self.collection_name}") client.delete_collection(self.collection_name) - + if not qdrant_collection_exists(client, self.collection_name): log.info(f"{self.name} create collection: {self.collection_name}") self._create_collection(dim, client) client = None - + @contextmanager def init(self): """ @@ -89,11 +90,15 @@ def init(self): yield self.client = None del self.client - + def _create_collection(self, dim: int, qdrant_client: QdrantClient): log.info(f"Create collection: {self.collection_name}") - log.info(f"Index parameters: m={self.case_config.index_param()['m']}, ef_construct={self.case_config.index_param()['ef_construct']}, on_disk={self.case_config.index_param()['on_disk']}") - + log.info( + f"Index parameters: m={self.case_config.index_param()['m']}, " + f"ef_construct={self.case_config.index_param()['ef_construct']}, " + f"on_disk={self.case_config.index_param()['on_disk']}" + ) + # If the on_disk is true, we enable both on disk index and vectors. try: qdrant_client.create_collection( @@ -104,10 +109,10 @@ def _create_collection(self, dim: int, qdrant_client: QdrantClient): on_disk=self.case_config.index_param()["on_disk"], ), hnsw_config=HnswConfigDiff( - m = self.case_config.index_param()["m"], + m=self.case_config.index_param()["m"], ef_construct=self.case_config.index_param()["ef_construct"], on_disk=self.case_config.index_param()["on_disk"], - ) + ), ) qdrant_client.create_payload_index( @@ -121,7 +126,7 @@ def _create_collection(self, dim: int, qdrant_client: QdrantClient): return log.warning(f"Failed to create collection: {self.collection_name} error: {e}") raise e from None - + def optimize(self, data_size: int | None = None): assert self.client, "Please call self.init() before" # wait for vectors to be fully indexed @@ -139,11 +144,11 @@ def optimize(self, data_size: int | None = None): ) log.info(msg) return - + except Exception as e: log.warning(f"QdrantCloud ready to search error: {e}") raise e from None - + def insert_embeddings( self, embeddings: Iterable[list[float]], @@ -163,7 +168,7 @@ def insert_embeddings( assert self.client is not None assert len(embeddings) == len(metadata) insert_count = 0 - + # disable indexing for quick insertion self.client.update_collection( collection_name=self.collection_name, @@ -185,13 +190,13 @@ def insert_embeddings( collection_name=self.collection_name, optimizer_config=OptimizersConfigDiff(indexing_threshold=100), ) - + except Exception as e: log.info(f"Failed to insert data, {e}") return insert_count, e else: return insert_count, None - + def search_embedding( self, query: list[float], @@ -203,7 +208,7 @@ def search_embedding( Should call self.init() first. """ assert self.client is not None - + f = None if filters: f = Filter( @@ -215,17 +220,13 @@ def search_embedding( ), ), ], - ) - res = ( - self.client.query_points( - collection_name=self.collection_name, - query=query, - limit=k, - query_filter=f, - search_params=SearchParams(**self.search_parameter), - - ).points - ) - - return [result.id for result in res] + ) + res = self.client.query_points( + collection_name=self.collection_name, + query=query, + limit=k, + query_filter=f, + search_params=SearchParams(**self.search_parameter), + ).points + return [result.id for result in res] diff --git a/vectordb_bench/backend/clients/weaviate_cloud/cli.py b/vectordb_bench/backend/clients/weaviate_cloud/cli.py index 9faf768a6..cba3c2377 100644 --- a/vectordb_bench/backend/clients/weaviate_cloud/cli.py +++ b/vectordb_bench/backend/clients/weaviate_cloud/cli.py @@ -15,7 +15,7 @@ class WeaviateTypedDict(CommonTypedDict): api_key: Annotated[ str, - click.option("--api-key", type=str, help="Weaviate api key", required=False, default=''), + click.option("--api-key", type=str, help="Weaviate api key", required=False, default=""), ] url: Annotated[ str, @@ -23,25 +23,24 @@ class WeaviateTypedDict(CommonTypedDict): ] no_auth: Annotated[ bool, - click.option("--no-auth", is_flag=True, help="Do not use api-key, set it to true if you are using a local setup. Default is False.", default=False), + click.option( + "--no-auth", + is_flag=True, + help="Do not use api-key, set it to true if you are using a local setup. Default is False.", + default=False, + ), ] m: Annotated[ int, - click.option( - "--m", type=int, default=16, help="HNSW index parameter m." - ), + click.option("--m", type=int, default=16, help="HNSW index parameter m."), ] ef_construct: Annotated[ int, - click.option( - "--ef-construction", type=int, default=256, help="HNSW index parameter ef_construction" - ), + click.option("--ef-construction", type=int, default=256, help="HNSW index parameter ef_construction"), ] ef: Annotated[ int, - click.option( - "--ef", type=int, default=256, help="HNSW index parameter ef for search" - ), + click.option("--ef", type=int, default=256, help="HNSW index parameter ef for search"), ] @@ -54,7 +53,7 @@ def Weaviate(**parameters: Unpack[WeaviateTypedDict]): db=DB.WeaviateCloud, db_config=WeaviateConfig( db_label=parameters["db_label"], - api_key=SecretStr(parameters["api_key"]) if parameters["api_key"] != '' else SecretStr("-"), + api_key=SecretStr(parameters["api_key"]) if parameters["api_key"] != "" else SecretStr("-"), url=SecretStr(parameters["url"]), no_auth=parameters["no_auth"], ), diff --git a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py index 18a17a661..d6111c8da 100644 --- a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +++ b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py @@ -37,11 +37,11 @@ def __init__( self._scalar_field = "key" self._vector_field = "vector" self._index_name = "vector_idx" - - # If local setup is used, we - if db_config['no_auth']: - del db_config['auth_client_secret'] - del db_config['no_auth'] + + # If local setup is used, we + if db_config["no_auth"]: + del db_config["auth_client_secret"] + del db_config["no_auth"] from weaviate import Client