Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 123 additions & 47 deletions vectordb_bench/backend/clients/clickhouse/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
from typing import Any

import clickhouse_connect
from clickhouse_connect.driver import Client

from ..api import DBCaseConfig, VectorDB
from .. import IndexType
from ..api import VectorDB
from .config import ClickhouseConfigDict, ClickhouseIndexConfig

log = logging.getLogger(__name__)

Expand All @@ -17,8 +20,8 @@ class Clickhouse(VectorDB):
def __init__(
self,
dim: int,
db_config: dict,
db_case_config: DBCaseConfig,
db_config: ClickhouseConfigDict,
db_case_config: ClickhouseIndexConfig,
collection_name: str = "CHVectorCollection",
drop_old: bool = False,
**kwargs,
Expand All @@ -28,84 +31,130 @@ def __init__(
self.table_name = collection_name
self.dim = dim

self.index_param = self.case_config.index_param()
self.search_param = self.case_config.search_param()
self.session_param = self.case_config.session_param()

self._index_name = "clickhouse_index"
self._primary_field = "id"
self._vector_field = "embedding"

# construct basic units
self.conn = clickhouse_connect.get_client(
host=self.db_config["host"],
port=self.db_config["port"],
username=self.db_config["user"],
password=self.db_config["password"],
database=self.db_config["dbname"],
)
self.conn = self._create_connection(**self.db_config, settings=self.session_param)

if drop_old:
log.info(f"Clickhouse client drop table : {self.table_name}")
self._drop_table()
self._create_table(dim)
if self.case_config.create_index_before_load:
self._create_index()

self.conn.close()
self.conn = None

@contextmanager
def init(self):
    """Open a ClickHouse connection for the duration of the ``with`` block.

    The connection is stored on ``self.conn`` and always closed (and reset
    to ``None``) on exit, even if the body raises.

    Examples:
        >>> with self.init():
        >>>     self.insert_embeddings()
        >>>     self.search_embedding()
    """
    # NOTE: no return annotation — @contextmanager wraps a generator, so an
    # "-> None" annotation here would be inaccurate.
    self.conn = self._create_connection(**self.db_config, settings=self.session_param)
    try:
        yield
    finally:
        self.conn.close()
        self.conn = None

def _create_connection(self, settings: dict | None = None, **kwargs) -> Client:
    """Create a clickhouse_connect client.

    Connection parameters come from ``self.db_config``; any explicit keyword
    arguments override them. ``settings`` is forwarded as the per-session
    ClickHouse settings dict.

    Fix: the original accepted ``**kwargs`` but silently discarded them
    (callers pass ``**self.db_config``, which only worked because the body
    re-read ``self.db_config``); it also had no default for ``settings``.
    """
    params = {**self.db_config, **kwargs}
    return clickhouse_connect.get_client(**params, settings=settings)

def _drop_index(self):
    """Drop the vector index from the collection table, logging on failure."""
    assert self.conn is not None, "Connection is not initialized"
    target = f'{self.db_config["database"]}.{self.table_name}'
    try:
        self.conn.command(f"ALTER TABLE {target} DROP INDEX {self._index_name}")
    except Exception as e:
        log.warning(f"Failed to drop index on table {target}: {e}")
        raise e from None

def _drop_table(self):
    """Drop the collection table if it exists.

    Fix: removed a stale duplicate DROP statement (diff residue) that read
    ``self.db_config["dbname"]`` — a key the current config dict does not
    contain — and would have raised KeyError before the real drop ran.
    """
    assert self.conn is not None, "Connection is not initialized"
    try:
        self.conn.command(f'DROP TABLE IF EXISTS {self.db_config["database"]}.{self.table_name}')
    except Exception as e:
        log.warning(f"Failed to drop table {self.db_config['database']}.{self.table_name}: {e}")
        raise e from None

def _perfomance_tuning(self):
    """Set the ``materialize_skip_indexes_on_insert`` session setting to 1.

    NOTE(review): method name has a typo ("perfomance") — kept unchanged
    because external callers may reference it.
    """
    # Consistency fix: every sibling method guards against a missing
    # connection; this one did not.
    assert self.conn is not None, "Connection is not initialized"
    self.conn.command("SET materialize_skip_indexes_on_insert = 1")

def _create_index(self):
    """Add an HNSW ``vector_similarity`` index on the embedding column.

    Uses the full parameter form (quantization, M, efConstruction) when all
    three are set in ``self.index_param``; otherwise falls back to
    ClickHouse's defaults. Only HNSW is supported; other index types are
    logged and skipped.

    Raises:
        Exception: re-raised (context suppressed) if the DDL fails.
    """
    assert self.conn is not None, "Connection is not initialized"
    try:
        if self.index_param["index_type"] == IndexType.HNSW.value:
            if (
                self.index_param["quantization"]
                and self.index_param["params"]["M"]
                and self.index_param["params"]["efConstruction"]
            ):
                query = f"""
                ALTER TABLE {self.db_config["database"]}.{self.table_name}
                ADD INDEX {self._index_name} {self._vector_field}
                TYPE vector_similarity('hnsw', '{self.index_param["metric_type"]}',
                '{self.index_param["quantization"]}',
                {self.index_param["params"]["M"]}, {self.index_param["params"]["efConstruction"]})
                GRANULARITY {self.index_param["granularity"]}
                """
            else:
                query = f"""
                ALTER TABLE {self.db_config["database"]}.{self.table_name}
                ADD INDEX {self._index_name} {self._vector_field}
                TYPE vector_similarity('hnsw', '{self.index_param["metric_type"]}')
                GRANULARITY {self.index_param["granularity"]}
                """
            self.conn.command(cmd=query)
        else:
            # Message fix: was "HNSW is only avaliable method in clickhouse now".
            log.warning("HNSW is the only available index method in Clickhouse now")
    except Exception as e:
        log.warning(f"Failed to create Clickhouse vector index on table: {self.table_name} error: {e}")
        raise e from None

def _create_table(self, dim: int):
    """Create the collection table: UInt32 id + fixed-length embedding array.

    The CHECK constraint enforces that every embedding has exactly ``dim``
    elements; CODEC(NONE) stores vectors uncompressed.

    Fix: removed a stale old-side DDL f-string (diff residue) that sat
    adjacent to the new one — implicit string concatenation would have sent
    malformed SQL ("… ORDER BY id;CREATE TABLE …") to the server, and it
    also read the obsolete ``"dbname"`` config key.
    """
    assert self.conn is not None, "Connection is not initialized"
    try:
        self.conn.command(
            f'CREATE TABLE IF NOT EXISTS {self.db_config["database"]}.{self.table_name} '
            f"({self._primary_field} UInt32, "
            f'{self._vector_field} Array({self.index_param["vector_data_type"]}) CODEC(NONE), '
            f"CONSTRAINT same_length CHECK length(embedding) = {dim}) "
            f"ENGINE = MergeTree() "
            f"ORDER BY {self._primary_field}"
        )
    except Exception as e:
        log.warning(f"Failed to create Clickhouse table: {self.table_name} error: {e}")
        raise e from None

def ready_to_load(self):
    # No-op hook: nothing to prepare before bulk loading for ClickHouse.
    pass

def optimize(self, data_size: int | None = None):
    # No-op hook: no post-load optimization is performed; data_size is unused.
    pass

def ready_to_search(self):
    # No-op hook: nothing to prepare before searching.
    pass

def _post_insert(self):
    # No-op hook: no work performed after inserts.
    pass

def insert_embeddings(
self,
embeddings: list[list[float]],
metadata: list[int],
**kwargs: Any,
) -> tuple[int, Exception]:
) -> (int, Exception):
assert self.conn is not None, "Connection is not initialized"

try:
Expand All @@ -116,7 +165,7 @@ def insert_embeddings(
table=self.table_name,
data=items,
column_names=["id", "embedding"],
column_type_names=["UInt32", "Array(Float64)"],
column_type_names=["UInt32", f'Array({self.index_param["vector_data_type"]})'],
column_oriented=True,
)
return len(metadata), None
Expand All @@ -132,25 +181,52 @@ def search_embedding(
timeout: int | None = None,
) -> list[int]:
assert self.conn is not None, "Connection is not initialized"

index_param = self.case_config.index_param() # noqa: F841
search_param = self.case_config.search_param()

if filters:
gt = filters.get("id")
filter_sql = (
f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score ' # noqa: S608
f'FROM {self.db_config["dbname"]}.{self.table_name} '
f"WHERE id > {gt} "
f"ORDER BY score LIMIT {k};"
)
result = self.conn.query(filter_sql).result_rows
parameters = {
"primary_field": self._primary_field,
"vector_field": self._vector_field,
"schema": self.db_config["database"],
"table": self.table_name,
"gt": filters.get("id"),
"k": k,
"metric_type": self.search_param["metric_type"],
"query": query,
}
if self.case_config.metric_type == "COSINE":
if filters:
result = self.conn.query(
"SELECT {primary_field:Identifier}, {vector_field:Identifier} "
"FROM {schema:Identifier}.{table:Identifier} "
"WHERE {primary_field:Identifier} > {gt:UInt32} "
"ORDER BY cosineDistance(embedding,{query:Array(Float64)}) "
"LIMIT {k:UInt32}",
parameters=parameters,
).result_rows
return [int(row[0]) for row in result]

result = self.conn.query(
"SELECT {primary_field:Identifier}, {vector_field:Identifier} "
"FROM {schema:Identifier}.{table:Identifier} "
"ORDER BY cosineDistance(embedding,{query:Array(Float64)}) "
"LIMIT {k:UInt32}",
parameters=parameters,
).result_rows
return [int(row[0]) for row in result]
else: # noqa: RET505
select_sql = (
f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score ' # noqa: S608
f'FROM {self.db_config["dbname"]}.{self.table_name} '
f"ORDER BY score LIMIT {k};"
)
result = self.conn.query(select_sql).result_rows
if filters:
result = self.conn.query(
"SELECT {primary_field:Identifier}, {vector_field:Identifier} "
"FROM {schema:Identifier}.{table:Identifier} "
"WHERE {primary_field:Identifier} > {gt:UInt32} "
"ORDER BY L2Distance(embedding,{query:Array(Float64)}) "
"LIMIT {k:UInt32}",
parameters=parameters,
).result_rows
return [int(row[0]) for row in result]

result = self.conn.query(
"SELECT {primary_field:Identifier}, {vector_field:Identifier} "
"FROM {schema:Identifier}.{table:Identifier} "
"ORDER BY L2Distance(embedding,{query:Array(Float64)}) "
"LIMIT {k:UInt32}",
parameters=parameters,
).result_rows
return [int(row[0]) for row in result]
49 changes: 39 additions & 10 deletions vectordb_bench/backend/clients/clickhouse/config.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,46 @@
from abc import abstractmethod
from typing import TypedDict

from pydantic import BaseModel, SecretStr

from ..api import DBCaseConfig, DBConfig, IndexType, MetricType


class ClickhouseConfigDict(TypedDict):
    """Keyword arguments passed to ``clickhouse_connect.get_client()``."""

    user: str  # login user name
    password: str  # plain-text password (unwrapped from SecretStr by to_dict)
    host: str  # server hostname
    port: int  # HTTP(S) interface port
    database: str  # target database name
    secure: bool  # whether to use TLS


class ClickhouseConfig(DBConfig):
    """Connection settings for the ClickHouse benchmark client.

    Fix: the span contained diff residue — a duplicated ``to_dict`` header
    and both the obsolete ``"dbname"`` and the current ``"database"`` keys
    in the returned dict; only the current form is kept.
    """

    user_name: str = "clickhouse"
    password: SecretStr
    host: str = "localhost"
    port: int = 8123
    db_name: str = "default"
    secure: bool = False

    def to_dict(self) -> ClickhouseConfigDict:
        """Return ``get_client()``-compatible kwargs, unwrapping the password."""
        pwd_str = self.password.get_secret_value()
        return {
            "host": self.host,
            "port": self.port,
            "database": self.db_name,  # ClickHouse driver expects "database"
            "user": self.user_name,
            "password": pwd_str,
            "secure": self.secure,
        }


class ClickhouseIndexConfig(BaseModel):
class ClickhouseIndexConfig(BaseModel, DBCaseConfig):

metric_type: MetricType | None = None
vector_data_type: str | None = "Float32" # Data type of vectors. Can be Float32 or Float64 or BFloat16
create_index_before_load: bool = True
create_index_after_load: bool = False

def parse_metric(self) -> str:
if not self.metric_type:
Expand All @@ -35,26 +52,38 @@ def parse_metric_str(self) -> str:
return "L2Distance"
if self.metric_type == MetricType.COSINE:
return "cosineDistance"
msg = f"Not Support for {self.metric_type}"
raise RuntimeError(msg)
return None
return "cosineDistance"

@abstractmethod
def session_param(self):
    """Return the ClickHouse session-settings dict for this index type (implemented by subclasses)."""
    pass


class ClickhouseHNSWConfig(ClickhouseIndexConfig):
    """HNSW ``vector_similarity`` index settings for ClickHouse.

    Fix: the span contained diff residue — a duplicated class header (the
    old one also listed ``DBCaseConfig``, now inherited via
    ``ClickhouseIndexConfig``) and a garbled ``"met˝ric_type"`` key left in
    ``search_param`` next to the corrected one; only the corrected form is
    kept.
    """

    M: int | None  # HNSW graph degree; ClickHouse default is 32
    efConstruction: int | None  # build-time beam width; ClickHouse default is 128
    ef: int | None = None  # search-time beam width
    index: IndexType = IndexType.HNSW
    quantization: str | None = "bf16"  # one of f64, f32, f16, bf16, or i8
    granularity: int | None = 10_000_000  # index granule size; ClickHouse default

    def index_param(self) -> dict:
        """Parameters used when creating the index."""
        return {
            "vector_data_type": self.vector_data_type,
            "metric_type": self.parse_metric_str(),
            "index_type": self.index.value,
            "quantization": self.quantization,
            "granularity": self.granularity,
            "params": {"M": self.M, "efConstruction": self.efConstruction},
        }

    def search_param(self) -> dict:
        """Parameters used at query time."""
        return {
            "metric_type": self.parse_metric_str(),
            "params": {"ef": self.ef},
        }

    def session_param(self) -> dict:
        """Session settings required to enable the experimental vector index."""
        return {
            "allow_experimental_vector_similarity_index": 1,
        }