Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ all = [
"pyvespa",
"lancedb",
"mysql-connector-python",
"turbopuffer[fast]",
]

qdrant = [ "qdrant-client" ]
Expand Down Expand Up @@ -102,6 +103,7 @@ lancedb = [ "lancedb" ]
oceanbase = [ "mysql-connector-python" ]
alisql = [ "mysql-connector-python" ]
doris = [ "doris-vector-search" ]
turbopuffer = [ "turbopuffer" ]

[project.urls]
"repository" = "https://github.com/zilliztech/VectorDBBench"
Expand Down
2 changes: 1 addition & 1 deletion vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class DB(Enum):
TencentElasticsearch = "TencentElasticsearch"
AliSQL = "AlibabaCloudRDSMySQL"
Doris = "Doris"
TurboPuffer = "TurpoBuffer"
TurboPuffer = "TurboPuffer"

@property
def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915
Expand Down
62 changes: 62 additions & 0 deletions vectordb_bench/backend/clients/turbopuffer/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import Annotated, TypedDict, Unpack

import click
from pydantic import SecretStr

from ....cli.cli import (
CommonTypedDict,
cli,
click_parameter_decorators_from_typed_dict,
run,
)
from .. import DB


class TurboPufferTypedDict(TypedDict):
api_key: Annotated[
str,
click.option("--api-key", type=str, help="TurboPuffer API key", required=True),
]
api_base_url: Annotated[
str,
click.option(
"--api-base-url",
type=str,
help="TurboPuffer API base URL",
required=False,
default="https://api.turbopuffer.com",
show_default=True,
),
]
namespace: Annotated[
str,
click.option(
"--namespace",
type=str,
help="TurboPuffer namespace",
required=False,
default="vdbbench_test",
show_default=True,
),
]


class TurboPufferIndexTypedDict(CommonTypedDict, TurboPufferTypedDict): ...


@cli.command()
@click_parameter_decorators_from_typed_dict(TurboPufferIndexTypedDict)
def TurboPuffer(**parameters: Unpack[TurboPufferIndexTypedDict]):
from .config import TurboPufferConfig, TurboPufferIndexConfig

run(
db=DB.TurboPuffer,
db_config=TurboPufferConfig(
db_label=parameters["db_label"],
api_key=SecretStr(parameters["api_key"]),
api_base_url=parameters["api_base_url"],
namespace=parameters["namespace"],
),
db_case_config=TurboPufferIndexConfig(),
**parameters,
)
24 changes: 12 additions & 12 deletions vectordb_bench/backend/clients/turbopuffer/turbopuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,21 @@ def __init__(
self._scalar_label_field = "label"

self.with_scalar_labels = with_scalar_labels

# Initialize client with new SDK pattern
self.client = tpuf.Turbopuffer(api_key=self.api_key, base_url=self.api_base_url)

if drop_old:
log.info(f"Drop old. delete the namespace: {self.namespace}")
tpuf.api_key = self.api_key
tpuf.api_base_url = self.api_base_url
ns = tpuf.Namespace(self.namespace)
ns = self.client.namespace(self.namespace)
try:
ns.delete_all()
except Exception as e:
log.warning(f"Failed to delete all. Error: {e}")

@contextmanager
def init(self):
tpuf.api_key = self.api_key
tpuf.api_base_url = self.api_base_url
self.ns = tpuf.Namespace(self.namespace)
self.ns = self.client.namespace(self.namespace)
yield

def optimize(self, data_size: int | None = None):
Expand All @@ -78,7 +78,7 @@ def insert_embeddings(
try:
if self.with_scalar_labels:
self.ns.write(
upsert_columns={
columns={
self._scalar_id_field: metadata,
self._vector_field: embeddings,
self._scalar_label_field: labels_data,
Expand All @@ -87,7 +87,7 @@ def insert_embeddings(
)
else:
self.ns.write(
upsert_columns={
columns={
self._scalar_id_field: metadata,
self._vector_field: embeddings,
},
Expand All @@ -104,19 +104,19 @@ def search_embedding(
timeout: int | None = None,
) -> list[int]:
res = self.ns.query(
rank_by=["vector", "ANN", query],
rank_by=("vector", "ANN", query),
top_k=k,
filters=self.expr,
)
return [row.id for row in res.rows]
return [row.id for row in res.rows] if res.rows is not None else []

def prepare_filter(self, filters: Filter):
if filters.type == FilterOp.NonFilter:
self.expr = None
elif filters.type == FilterOp.NumGE:
self.expr = [self._scalar_id_field, "Gte", filters.int_value]
self.expr = (self._scalar_id_field, "Gte", filters.int_value)
elif filters.type == FilterOp.StrEqual:
self.expr = [self._scalar_label_field, "Eq", filters.label_value]
self.expr = (self._scalar_label_field, "Eq", filters.label_value)
else:
msg = f"Not support Filter for TurboPuffer - {filters}"
raise ValueError(msg)
2 changes: 2 additions & 0 deletions vectordb_bench/cli/vectordbbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ..backend.clients.tencent_elasticsearch.cli import TencentElasticsearch
from ..backend.clients.test.cli import Test
from ..backend.clients.tidb.cli import TiDB
from ..backend.clients.turbopuffer.cli import TurboPuffer
from ..backend.clients.vespa.cli import Vespa
from ..backend.clients.weaviate_cloud.cli import Weaviate
from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex
Expand Down Expand Up @@ -58,6 +59,7 @@
cli.add_command(TencentElasticsearch)
cli.add_command(AliSQLHNSW)
cli.add_command(Doris)
cli.add_command(TurboPuffer)


if __name__ == "__main__":
Expand Down