Skip to content

Commit 1446c6e

Browse files
nuvotex-tkalwayslove2013
authored andcommitted
Add vespa integration
1 parent a39fe83 commit 1446c6e

11 files changed

Lines changed: 452 additions & 2 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ All the database client supported
5656
| aliyun_opensearch | `pip install vectordb-bench[aliyun_opensearch]` |
5757
| mongodb | `pip install vectordb-bench[mongodb]` |
5858
| tidb | `pip install vectordb-bench[tidb]` |
59+
| vespa | `pip install vectordb-bench[vespa]` |
5960

6061
### Run
6162

install/requirements_py3.11.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ pydantic<v2
2323
scikit-learn
2424
pymilvus
2525
clickhouse_connect
26+
pyvespa

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ all = [
7070
"alibabacloud_searchengine20211025",
7171
"mariadb",
7272
"PyMySQL",
73-
"clickhouse-connect"
73+
"clickhouse-connect",
74+
"pyvespa",
7475
]
7576

7677
qdrant = [ "qdrant-client" ]
@@ -92,6 +93,7 @@ mongodb = [ "pymongo" ]
9293
mariadb = [ "mariadb" ]
9394
tidb = [ "PyMySQL" ]
9495
clickhouse = [ "clickhouse-connect" ]
96+
vespa = [ "pyvespa" ]
9597

9698
[project.urls]
9799
"repository" = "https://github.com/zilliztech/VectorDBBench"

vectordb_bench/backend/clients/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class DB(Enum):
4444
MongoDB = "MongoDB"
4545
TiDB = "TiDB"
4646
Clickhouse = "Clickhouse"
47+
Vespa = "Vespa"
4748

4849
@property
4950
def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901
@@ -157,6 +158,11 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901
157158
from .test.test import Test
158159

159160
return Test
161+
162+
if self == DB.Vespa:
163+
from .vespa.vespa import Vespa
164+
165+
return Vespa
160166

161167
msg = f"Unknown DB: {self.name}"
162168
raise ValueError(msg)
@@ -273,6 +279,12 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901
273279
from .test.config import TestConfig
274280

275281
return TestConfig
282+
283+
if self == DB.Vespa:
284+
from .vespa.config import VespaConfig
285+
286+
return VespaConfig
287+
276288

277289
msg = f"Unknown DB: {self.name}"
278290
raise ValueError(msg)
@@ -365,6 +377,11 @@ def case_config_cls( # noqa: PLR0911
365377
from .tidb.config import TiDBIndexConfig
366378

367379
return TiDBIndexConfig
380+
381+
if self == DB.Vespa:
382+
from .vespa.config import VespaHNSWConfig
383+
384+
return VespaHNSWConfig
368385

369386
# DB.Pinecone, DB.Chroma, DB.Redis
370387
return EmptyDBCaseConfig
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from typing import Annotated, Unpack
2+
3+
import click
4+
from pydantic import SecretStr
5+
6+
from vectordb_bench.backend.clients import DB
7+
from vectordb_bench.cli.cli import (
8+
CommonTypedDict,
9+
HNSWFlavor1,
10+
cli,
11+
click_parameter_decorators_from_typed_dict,
12+
run,
13+
)
14+
15+
16+
class VespaTypedDict(CommonTypedDict, HNSWFlavor1):
17+
uri: Annotated[
18+
str,
19+
click.option("--uri", "-u", type=str, help="uri connection string", default="http://127.0.0.1"),
20+
]
21+
port: Annotated[
22+
int,
23+
click.option("--port", "-p", type=int, help="connection port", default=8080),
24+
]
25+
quantization: Annotated[
26+
str, click.option("--quantization", type=click.Choice(["none", "binary"], case_sensitive=False), default="none")
27+
]
28+
29+
30+
@cli.command()
31+
@click_parameter_decorators_from_typed_dict(VespaTypedDict)
32+
def Vespa(**params: Unpack[VespaTypedDict]):
33+
from .config import VespaConfig, VespaHNSWConfig
34+
35+
case_params = {
36+
"quantization_type": params["quantization"],
37+
"M": params["m"],
38+
"efConstruction": params["ef_construction"],
39+
"ef": params["ef_search"],
40+
}
41+
42+
run(
43+
db=DB.Vespa,
44+
db_config=VespaConfig(url=SecretStr(params["uri"]), port=params["port"]),
45+
db_case_config=VespaHNSWConfig(**{k: v for k, v in case_params.items() if v}),
46+
**params,
47+
)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from typing import Literal, TypeAlias
2+
3+
from pydantic import BaseModel, SecretStr
4+
5+
from ..api import DBCaseConfig, DBConfig, MetricType
6+
7+
VespaMetric: TypeAlias = Literal["euclidean", "angular", "dotproduct", "prenormalized-angular", "hamming", "geodegrees"]
8+
9+
VespaQuantizationType: TypeAlias = Literal["none", "binary"]
10+
11+
12+
class VespaConfig(DBConfig):
13+
url: SecretStr = "http://127.0.0.1"
14+
port: int = 8080
15+
16+
def to_dict(self):
17+
return {
18+
"url": self.url.get_secret_value(),
19+
"port": self.port,
20+
}
21+
22+
23+
class VespaHNSWConfig(BaseModel, DBCaseConfig):
24+
metric_type: MetricType = MetricType.COSINE
25+
quantization_type: VespaQuantizationType = "none"
26+
M: int = 16
27+
efConstruction: int = 200
28+
ef: int = 100
29+
30+
def index_param(self) -> dict:
31+
return {
32+
"distance_metric": self.parse_metric(self.metric_type),
33+
"max_links_per_node": self.M,
34+
"neighbors_to_explore_at_insert": self.efConstruction,
35+
}
36+
37+
def search_param(self) -> dict:
38+
return {}
39+
40+
def parse_metric(self, metric_type: MetricType) -> VespaMetric:
41+
match metric_type:
42+
case MetricType.COSINE:
43+
return "angular"
44+
case MetricType.L2:
45+
return "euclidean"
46+
case MetricType.DP | MetricType.IP:
47+
return "dotproduct"
48+
case MetricType.HAMMING:
49+
return "hamming"
50+
case _:
51+
raise NotImplementedError
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""Utility functions for supporting binary quantization
2+
3+
From https://docs.vespa.ai/en/binarizing-vectors.html#appendix-conversion-to-int8
4+
"""
5+
import numpy as np
6+
7+
8+
def binarize_tensor(tensor: list[float]) -> list[int]:
9+
"""
10+
Binarize a floating-point list by thresholding at zero
11+
and packing the bits into bytes.
12+
"""
13+
tensor = np.array(tensor)
14+
return (
15+
np.packbits(np.where(tensor > 0, 1, 0), axis=0).astype(np.int8).tolist()
16+
)

0 commit comments

Comments
 (0)