Skip to content

Commit a39fe83

Browse files
MansorY23 authored and alwayslove2013 committed
feat: initial commit
1 parent 4cbfef7 commit a39fe83

7 files changed

Lines changed: 292 additions & 0 deletions

File tree

install/requirements_py3.11.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ environs
2222
pydantic<v2
2323
scikit-learn
2424
pymilvus
25+
clickhouse_connect

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ all = [
7070
"alibabacloud_searchengine20211025",
7171
"mariadb",
7272
"PyMySQL",
73+
"clickhouse-connect"
7374
]
7475

7576
qdrant = [ "qdrant-client" ]
@@ -90,6 +91,7 @@ aliyun_opensearch = [ "alibabacloud_ha3engine_vector", "alibabacloud_searchengin
9091
mongodb = [ "pymongo" ]
9192
mariadb = [ "mariadb" ]
9293
tidb = [ "PyMySQL" ]
94+
clickhouse = [ "clickhouse-connect" ]
9395

9496
[project.urls]
9597
"repository" = "https://github.com/zilliztech/VectorDBBench"

vectordb_bench/backend/clients/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class DB(Enum):
4343
AliyunOpenSearch = "AliyunOpenSearch"
4444
MongoDB = "MongoDB"
4545
TiDB = "TiDB"
46+
Clickhouse = "Clickhouse"
4647

4748
@property
4849
def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901
@@ -117,6 +118,11 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901
117118

118119
return AWSOpenSearch
119120

121+
if self == DB.Clickhouse:
122+
from .clickhouse.clickhouse import Clickhouse
123+
124+
return Clickhouse
125+
120126
if self == DB.AlloyDB:
121127
from .alloydb.alloydb import AlloyDB
122128

@@ -228,6 +234,11 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901
228234

229235
return AWSOpenSearchConfig
230236

237+
if self == DB.Clickhouse:
238+
from .clickhouse.config import ClickhouseConfig
239+
240+
return ClickhouseConfig
241+
231242
if self == DB.AlloyDB:
232243
from .alloydb.config import AlloyDBConfig
233244

@@ -310,6 +321,11 @@ def case_config_cls( # noqa: PLR0911
310321

311322
return AWSOpenSearchIndexConfig
312323

324+
if self == DB.Clickhouse:
325+
from .clickhouse.config import ClickhouseHNSWConfig
326+
327+
return ClickhouseHNSWConfig
328+
313329
if self == DB.PgVectorScale:
314330
from .pgvectorscale.config import _pgvectorscale_case_config
315331

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from typing import Annotated, TypedDict, Unpack
2+
3+
import click
4+
from pydantic import SecretStr
5+
6+
from ....cli.cli import (
7+
CommonTypedDict,
8+
HNSWFlavor2,
9+
cli,
10+
click_parameter_decorators_from_typed_dict,
11+
run,
12+
)
13+
from .. import DB
14+
from .config import ClickhouseHNSWConfig
15+
16+
17+
class ClickhouseTypedDict(TypedDict):
    """Command-line options describing how to reach a Clickhouse server.

    Every key carries, as ``Annotated`` metadata, the ``click.option`` that
    populates it.
    """

    password: Annotated[str, click.option("--password", type=str, help="DB password")]
    host: Annotated[str, click.option("--host", type=str, help="DB host", required=True)]
    port: Annotated[int, click.option("--port", type=int, default=8123, help="DB Port")]
    # The option is parsed with ``type=str`` and defaults to a string, so the
    # value is a ``str`` (it was previously mis-annotated as ``int``).
    user: Annotated[str, click.option("--user", type=str, default='clickhouse', help="DB user")]
    ssl: Annotated[
        bool,
        click.option(
            "--ssl/--no-ssl",
            is_flag=True,
            show_default=True,
            default=True,
            help="Enable or disable SSL for Clickhouse",
        ),
    ]
    ssl_ca_certs: Annotated[
        str,
        click.option(
            "--ssl-ca-certs",
            show_default=True,
            help="Path to certificate authority file to use for SSL",
        ),
    ]
40+
41+
42+
class ClickhouseHNSWTypedDict(CommonTypedDict, ClickhouseTypedDict, HNSWFlavor2):
    """Union of the common, Clickhouse connection and HNSW command-line options."""
43+
44+
45+
@cli.command()
@click_parameter_decorators_from_typed_dict(ClickhouseHNSWTypedDict)
def Clickhouse(**parameters: Unpack[ClickhouseHNSWTypedDict]):
    # CLI entry point: turn the parsed options into a ClickhouseConfig and a
    # ClickhouseHNSWConfig, then hand everything to ``run``.
    # (Kept as a comment rather than a docstring so the command's --help
    # output is unchanged.)
    from .config import ClickhouseConfig

    run(
        db=DB.Clickhouse,
        db_config=ClickhouseConfig(
            db_label=parameters["db_label"],
            # Forward --user. It used to be parsed but never passed on, so the
            # config always fell back to its default user name.
            user_name=parameters["user"],
            password=SecretStr(parameters["password"]) if parameters["password"] else None,
            host=parameters["host"],
            port=parameters["port"],
            # NOTE(review): ClickhouseConfig declares no ssl / ssl_ca_certs
            # fields in config.py — confirm these two values are consumed.
            ssl=parameters["ssl"],
            ssl_ca_certs=parameters["ssl_ca_certs"],
        ),
        db_case_config=ClickhouseHNSWConfig(
            M=parameters["m"],
            efConstruction=parameters["ef_construction"],
            ef=parameters["ef_runtime"],
        ),
        **parameters,
    )
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""Wrapper around the Clickhouse vector database over VectorDB"""
2+
3+
import io
4+
import logging
5+
from contextlib import contextmanager
6+
from typing import Any
7+
import clickhouse_connect
8+
import numpy as np
9+
10+
from ..api import VectorDB, DBCaseConfig
11+
12+
log = logging.getLogger(__name__)
13+
14+
class Clickhouse(VectorDB):
    """VectorDB client that stores vectors in a Clickhouse table.

    The server is accessed through ``clickhouse_connect``. Vectors live in
    ``<dbname>.<collection_name>`` with an ``id UInt32`` column and an
    ``embedding Array(Float64)`` column. Searches order rows by the distance
    function named by the case config's ``metric_type`` search parameter.
    """

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: DBCaseConfig,
        collection_name: str = "CHVectorCollection",
        drop_old: bool = False,
        **kwargs,
    ):
        """Store the configuration and, if requested, recreate the table.

        Args:
            dim: Vector dimension (forwarded to ``_create_table``).
            db_config: Connection settings with the keys ``host``, ``port``,
                ``user``, ``password`` and ``dbname``.
            db_case_config: Case configuration providing ``search_param()``.
            collection_name: Name of the table holding the vectors.
            drop_old: When true, drop the table and create it again.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        self.db_config = db_config
        self.case_config = db_case_config
        self.table_name = collection_name
        self.dim = dim

        self._index_name = "clickhouse_index"
        self._primary_field = "id"
        self._vector_field = "embedding"

        # A short-lived connection is used for setup only; ``init`` opens the
        # one used for loading and searching.
        self.conn = self._create_connection()
        try:
            if drop_old:
                log.info(f"Clickhouse client drop table : {self.table_name}")
                self._drop_table()
                self._create_table(dim)
        finally:
            # Close even when the drop/create fails so the connection is not leaked.
            self.conn.close()
            self.conn = None

    def _create_connection(self):
        """Open a new ``clickhouse_connect`` client from ``self.db_config``."""
        return clickhouse_connect.get_client(
            host=self.db_config["host"],
            port=self.db_config["port"],
            username=self.db_config["user"],
            password=self.db_config["password"],
            database=self.db_config["dbname"],
        )

    def _qualified_table_name(self) -> str:
        """Return the table name prefixed with the configured database."""
        return f'{self.db_config["dbname"]}.{self.table_name}'

    @contextmanager
    def init(self) -> None:
        """Open a connection for the duration of the ``with`` block.

        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
            >>>     self.search_embedding()
        """
        self.conn = self._create_connection()
        try:
            yield
        finally:
            self.conn.close()
            self.conn = None

    def _drop_table(self):
        """Drop the vector table if it exists."""
        assert self.conn is not None, "Connection is not initialized"

        self.conn.command(f"DROP TABLE IF EXISTS {self._qualified_table_name()}")

    def _create_table(self, dim: int):
        """Create the vector table if it does not exist yet.

        Args:
            dim: Vector dimension; not used by the current table definition.

        Raises:
            Exception: Whatever the client raises, after logging a warning.
        """
        assert self.conn is not None, "Connection is not initialized"

        try:
            self.conn.command(
                f"CREATE TABLE IF NOT EXISTS {self._qualified_table_name()} "
                "(id UInt32, embedding Array(Float64)) ENGINE = MergeTree() ORDER BY id;"
            )
        except Exception as e:
            log.warning(
                f"Failed to create Clickhouse table: {self.table_name} error: {e}"
            )
            raise

    def ready_to_load(self):
        """No preparation is performed before loading."""

    def optimize(self, data_size: int | None = None):
        """No optimization step is performed."""

    def ready_to_search(self):
        """No preparation is performed before searching."""

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> tuple[int, Exception | None]:
        """Bulk-insert vectors together with their ids.

        Args:
            embeddings: Vectors to insert.
            metadata: Ids, one per vector, in the same order.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ``(len(metadata), None)`` on success, ``(0, error)`` on failure.
        """
        assert self.conn is not None, "Connection is not initialized"

        try:
            # Column-oriented payload: one list per column, no per-row iteration.
            items = [metadata, embeddings]

            self.conn.insert(
                table=self.table_name,
                data=items,
                column_names=['id', 'embedding'],
                column_type_names=['UInt32', 'Array(Float64)'],
                column_oriented=True,
            )
            return len(metadata), None
        except Exception as e:
            log.warning(f"Failed to insert data into Clickhouse table ({self.table_name}), error: {e}")
            return 0, e

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
    ) -> list[int]:
        """Return the ids of the ``k`` rows closest to ``query``.

        Args:
            query: Query vector, rendered into the SQL text as an array literal.
            k: Maximum number of ids to return.
            filters: Optional filter; when given, only rows whose id is
                greater than ``filters["id"]`` are considered.
            timeout: Accepted for interface compatibility; unused here.

        Returns:
            Ids ordered by ascending distance score.
        """
        assert self.conn is not None, "Connection is not initialized"

        search_param = self.case_config.search_param()

        # Both the filtered and unfiltered query share everything but the WHERE clause.
        where_clause = ""
        if filters:
            gt = filters.get("id")
            where_clause = f'WHERE id > {gt} '

        select_sql = (
            f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score '
            f'FROM {self._qualified_table_name()} '
            f'{where_clause}'
            f'ORDER BY score LIMIT {k};'
        )
        result = self.conn.query(select_sql).result_rows
        return [int(row[0]) for row in result]
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from typing import TypedDict
2+
from pydantic import BaseModel, SecretStr
3+
from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
4+
5+
class ClickhouseConfig(DBConfig):
    """Connection settings for a Clickhouse server."""

    user_name: str = "clickhouse"
    password: SecretStr
    host: str = "localhost"
    port: int = 8123
    db_name: str = "default"

    def to_dict(self) -> dict:
        """Return the settings as a plain dict with the password unwrapped."""
        plain_password = self.password.get_secret_value()
        return dict(
            host=self.host,
            port=self.port,
            dbname=self.db_name,
            user=self.user_name,
            password=plain_password,
        )
21+
22+
23+
class ClickhouseIndexConfig(BaseModel):
    """Metric handling shared by the Clickhouse case configurations."""

    metric_type: MetricType | None = None

    def parse_metric(self) -> str:
        """Return the metric's enum value, or ``""`` when no metric is set."""
        if not self.metric_type:
            return ""
        return self.metric_type.value

    def parse_metric_str(self) -> str:
        """Return the name of the Clickhouse distance function for the metric.

        Raises:
            ValueError: If the metric is unset or is neither L2 nor COSINE.
                The function used to return ``None`` in that case, which
                callers interpolate into SQL as the literal text ``None``.
        """
        if self.metric_type == MetricType.L2:
            return "L2Distance"
        if self.metric_type == MetricType.COSINE:
            return "cosineDistance"
        msg = f"Unsupported metric type for Clickhouse: {self.metric_type}"
        raise ValueError(msg)
37+
38+
39+
class ClickhouseHNSWConfig(ClickhouseIndexConfig, DBCaseConfig):
    """HNSW case configuration for Clickhouse."""

    M: int | None
    efConstruction: int | None
    ef: int | None = None
    index: IndexType = IndexType.HNSW

    def index_param(self) -> dict:
        """Return the metric function name, index type and build parameters."""
        metric_name = self.parse_metric_str()
        index_kind = self.index.value
        build_params = dict(M=self.M, efConstruction=self.efConstruction)
        return dict(metric_type=metric_name, index_type=index_kind, params=build_params)

    def search_param(self) -> dict:
        """Return the metric function name and search-time parameters."""
        metric_name = self.parse_metric_str()
        return dict(metric_type=metric_name, params=dict(ef=self.ef))

vectordb_bench/cli/vectordbbench.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ..backend.clients.weaviate_cloud.cli import Weaviate
1313
from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex
1414
from ..backend.clients.tidb.cli import TiDB
15+
from ..backend.clients.clickhouse.cli import Clickhouse
1516
from .cli import cli
1617

1718
cli.add_command(PgVectorHNSW)
@@ -29,6 +30,7 @@
2930
cli.add_command(AlloyDBScaNN)
3031
cli.add_command(MariaDBHNSW)
3132
cli.add_command(TiDB)
33+
cli.add_command(Clickhouse)
3234

3335

3436
if __name__ == "__main__":

0 commit comments

Comments
 (0)