diff --git a/pyproject.toml b/pyproject.toml index c4563b26e..e63e85387 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ all = [ "pyvespa", "lancedb", "mysql-connector-python", + "turbopuffer[fast]", ] qdrant = [ "qdrant-client" ] @@ -102,6 +103,7 @@ lancedb = [ "lancedb" ] oceanbase = [ "mysql-connector-python" ] alisql = [ "mysql-connector-python" ] doris = [ "doris-vector-search" ] +turbopuffer = [ "turbopuffer" ] [project.urls] "repository" = "https://github.com/zilliztech/VectorDBBench" diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index e96622970..3bd2d3189 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -55,7 +55,7 @@ class DB(Enum): TencentElasticsearch = "TencentElasticsearch" AliSQL = "AlibabaCloudRDSMySQL" Doris = "Doris" - TurboPuffer = "TurpoBuffer" + TurboPuffer = "TurboPuffer" @property def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915 diff --git a/vectordb_bench/backend/clients/turbopuffer/cli.py b/vectordb_bench/backend/clients/turbopuffer/cli.py new file mode 100644 index 000000000..6fd91f2a8 --- /dev/null +++ b/vectordb_bench/backend/clients/turbopuffer/cli.py @@ -0,0 +1,62 @@ +from typing import Annotated, TypedDict, Unpack + +import click +from pydantic import SecretStr + +from ....cli.cli import ( + CommonTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from .. import DB + + +class TurboPufferTypedDict(TypedDict): + api_key: Annotated[ + str, + click.option("--api-key", type=str, help="TurboPuffer API key", required=True), + ] + api_base_url: Annotated[ + str, + click.option( + "--api-base-url", + type=str, + help="TurboPuffer API base URL", + required=False, + default="https://api.turbopuffer.com", + show_default=True, + ), + ] + namespace: Annotated[ + str, + click.option( + "--namespace", + type=str, + help="TurboPuffer namespace", + required=False, + default="vdbbench_test", + show_default=True, + ), + ] + + +class TurboPufferIndexTypedDict(CommonTypedDict, TurboPufferTypedDict): ... + + +@cli.command() +@click_parameter_decorators_from_typed_dict(TurboPufferIndexTypedDict) +def TurboPuffer(**parameters: Unpack[TurboPufferIndexTypedDict]): + from .config import TurboPufferConfig, TurboPufferIndexConfig + + run( + db=DB.TurboPuffer, + db_config=TurboPufferConfig( + db_label=parameters["db_label"], + api_key=SecretStr(parameters["api_key"]), + api_base_url=parameters["api_base_url"], + namespace=parameters["namespace"], + ), + db_case_config=TurboPufferIndexConfig(), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/turbopuffer/turbopuffer.py b/vectordb_bench/backend/clients/turbopuffer/turbopuffer.py index 449682e8c..241551792 100644 --- a/vectordb_bench/backend/clients/turbopuffer/turbopuffer.py +++ b/vectordb_bench/backend/clients/turbopuffer/turbopuffer.py @@ -42,11 +42,13 @@ def __init__( self._scalar_label_field = "label" self.with_scalar_labels = with_scalar_labels + + # Initialize client with new SDK pattern + self.client = tpuf.Turbopuffer(api_key=self.api_key, base_url=self.api_base_url) + if drop_old: log.info(f"Drop old. delete the namespace: {self.namespace}") - tpuf.api_key = self.api_key - tpuf.api_base_url = self.api_base_url - ns = tpuf.Namespace(self.namespace) + ns = self.client.namespace(self.namespace) try: ns.delete_all() except Exception as e: @@ -54,9 +56,7 @@ def __init__( @contextmanager def init(self): - tpuf.api_key = self.api_key - tpuf.api_base_url = self.api_base_url - self.ns = tpuf.Namespace(self.namespace) + self.ns = self.client.namespace(self.namespace) yield def optimize(self, data_size: int | None = None): @@ -78,7 +78,7 @@ def insert_embeddings( try: if self.with_scalar_labels: self.ns.write( - upsert_columns={ + columns={ self._scalar_id_field: metadata, self._vector_field: embeddings, self._scalar_label_field: labels_data, @@ -87,7 +87,7 @@ def insert_embeddings( ) else: self.ns.write( - upsert_columns={ + columns={ self._scalar_id_field: metadata, self._vector_field: embeddings, }, @@ -104,19 +104,19 @@ def search_embedding( timeout: int | None = None, ) -> list[int]: res = self.ns.query( - rank_by=["vector", "ANN", query], + rank_by=("vector", "ANN", query), top_k=k, filters=self.expr, ) - return [row.id for row in res.rows] + return [row.id for row in res.rows] if res.rows is not None else [] def prepare_filter(self, filters: Filter): if filters.type == FilterOp.NonFilter: self.expr = None elif filters.type == FilterOp.NumGE: - self.expr = [self._scalar_id_field, "Gte", filters.int_value] + self.expr = (self._scalar_id_field, "Gte", filters.int_value) elif filters.type == FilterOp.StrEqual: - self.expr = [self._scalar_label_field, "Eq", filters.label_value] + self.expr = (self._scalar_label_field, "Eq", filters.label_value) else: msg = f"Not support Filter for TurboPuffer - {filters}" raise ValueError(msg) diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 1c8df6c19..f25f5649f 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -22,6 +22,7 @@ from ..backend.clients.tencent_elasticsearch.cli import TencentElasticsearch from ..backend.clients.test.cli import Test from ..backend.clients.tidb.cli import TiDB +from ..backend.clients.turbopuffer.cli import TurboPuffer from ..backend.clients.vespa.cli import Vespa from ..backend.clients.weaviate_cloud.cli import Weaviate from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex @@ -58,6 +59,7 @@ cli.add_command(TencentElasticsearch) cli.add_command(AliSQLHNSW) cli.add_command(Doris) +cli.add_command(TurboPuffer) if __name__ == "__main__":