Skip to content

Commit 8fdfdc3

Browse files
reformat all
Signed-off-by: min.tian <min.tian.cn@gmail.com>
1 parent f0e8367 commit 8fdfdc3

24 files changed

Lines changed: 226 additions & 222 deletions

File tree

vectordb_bench/backend/clients/__init__.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901
158158
from .test.test import Test
159159

160160
return Test
161-
161+
162162
if self == DB.Vespa:
163163
from .vespa.vespa import Vespa
164164

@@ -279,17 +279,16 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901
279279
from .test.config import TestConfig
280280

281281
return TestConfig
282-
282+
283283
if self == DB.Vespa:
284284
from .vespa.config import VespaConfig
285285

286286
return VespaConfig
287287

288-
289288
msg = f"Unknown DB: {self.name}"
290289
raise ValueError(msg)
291290

292-
def case_config_cls( # noqa: PLR0911
291+
def case_config_cls( # noqa: C901, PLR0911, PLR0912
293292
self,
294293
index_type: IndexType | None = None,
295294
) -> type[DBCaseConfig]:
@@ -377,7 +376,7 @@ def case_config_cls( # noqa: PLR0911
377376
from .tidb.config import TiDBIndexConfig
378377

379378
return TiDBIndexConfig
380-
379+
381380
if self == DB.Vespa:
382381
from .vespa.config import VespaHNSWConfig
383382

vectordb_bench/backend/clients/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def insert_embeddings(
162162
embeddings: list[list[float]],
163163
metadata: list[int],
164164
**kwargs,
165-
) -> (int, Exception):
165+
) -> tuple[int, Exception]:
166166
"""Insert the embeddings to the vector database. The default number of embeddings for
167167
each insert_embeddings is 5000.
168168

vectordb_bench/backend/clients/chroma/chroma.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def insert_embeddings(
6565
embeddings: list[list[float]],
6666
metadata: list[int],
6767
**kwargs: Any,
68-
) -> (int, Exception):
68+
) -> tuple[int, Exception]:
6969
"""Insert embeddings into the database.
7070
7171
Args:
@@ -74,7 +74,7 @@ def insert_embeddings(
7474
kwargs: other arguments
7575
7676
Returns:
77-
(int, Exception): number of embeddings inserted and exception if any
77+
tuple[int, Exception]: number of embeddings inserted and exception if any
7878
"""
7979
ids = [str(i) for i in metadata]
8080
metadata = [{"id": int(i)} for i in metadata]

vectordb_bench/backend/clients/clickhouse/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class ClickhouseTypedDict(TypedDict):
1818
password: Annotated[str, click.option("--password", type=str, help="DB password")]
1919
host: Annotated[str, click.option("--host", type=str, help="DB host", required=True)]
2020
port: Annotated[int, click.option("--port", type=int, default=8123, help="DB Port")]
21-
user: Annotated[int, click.option("--user", type=str, default='clickhouse', help="DB user")]
21+
user: Annotated[int, click.option("--user", type=str, default="clickhouse", help="DB user")]
2222
ssl: Annotated[
2323
bool,
2424
click.option(

vectordb_bench/backend/clients/clickhouse/clickhouse.py

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
"""Wrapper around the Clickhouse vector database over VectorDB"""
22

3-
import io
43
import logging
54
from contextlib import contextmanager
65
from typing import Any
6+
77
import clickhouse_connect
8-
import numpy as np
98

10-
from ..api import VectorDB, DBCaseConfig
9+
from ..api import DBCaseConfig, VectorDB
1110

1211
log = logging.getLogger(__name__)
1312

13+
1414
class Clickhouse(VectorDB):
1515
"""Use SQLAlchemy instructions"""
16+
1617
def __init__(
1718
self,
1819
dim: int,
@@ -32,12 +33,13 @@ def __init__(
3233
self._vector_field = "embedding"
3334

3435
# construct basic units
35-
self.conn = clickhouse_connect.get_client(
36-
host=self.db_config["host"],
37-
port=self.db_config["port"],
38-
username=self.db_config["user"],
39-
password=self.db_config["password"],
40-
database=self.db_config["dbname"])
36+
self.conn = clickhouse_connect.get_client(
37+
host=self.db_config["host"],
38+
port=self.db_config["port"],
39+
username=self.db_config["user"],
40+
password=self.db_config["password"],
41+
database=self.db_config["dbname"],
42+
)
4143

4244
if drop_old:
4345
log.info(f"Clickhouse client drop table : {self.table_name}")
@@ -48,20 +50,21 @@ def __init__(
4850
self.conn = None
4951

5052
@contextmanager
51-
def init(self) -> None:
53+
def init(self):
5254
"""
5355
Examples:
5456
>>> with self.init():
5557
>>> self.insert_embeddings()
5658
>>> self.search_embedding()
5759
"""
5860

59-
self.conn = clickhouse_connect.get_client(
60-
host=self.db_config["host"],
61-
port=self.db_config["port"],
62-
username=self.db_config["user"],
63-
password=self.db_config["password"],
64-
database=self.db_config["dbname"])
61+
self.conn = clickhouse_connect.get_client(
62+
host=self.db_config["host"],
63+
port=self.db_config["port"],
64+
username=self.db_config["user"],
65+
password=self.db_config["password"],
66+
database=self.db_config["dbname"],
67+
)
6568

6669
try:
6770
yield
@@ -85,9 +88,7 @@ def _create_table(self, dim: int):
8588
)
8689

8790
except Exception as e:
88-
log.warning(
89-
f"Failed to create Clickhouse table: {self.table_name} error: {e}"
90-
)
91+
log.warning(f"Failed to create Clickhouse table: {self.table_name} error: {e}")
9192
raise e from None
9293

9394
def ready_to_load(self):
@@ -104,16 +105,20 @@ def insert_embeddings(
104105
embeddings: list[list[float]],
105106
metadata: list[int],
106107
**kwargs: Any,
107-
) -> (int, Exception):
108+
) -> tuple[int, Exception]:
108109
assert self.conn is not None, "Connection is not initialized"
109110

110111
try:
111112
# do not iterate for bulk insert
112113
items = [metadata, embeddings]
113114

114-
self.conn.insert(table=self.table_name, data=items,
115-
column_names=['id', 'embedding'], column_type_names=['UInt32', 'Array(Float64)'],
116-
column_oriented=True)
115+
self.conn.insert(
116+
table=self.table_name,
117+
data=items,
118+
column_names=["id", "embedding"],
119+
column_type_names=["UInt32", "Array(Float64)"],
120+
column_oriented=True,
121+
)
117122
return len(metadata), None
118123
except Exception as e:
119124
log.warning(f"Failed to insert data into Clickhouse table ({self.table_name}), error: {e}")
@@ -128,22 +133,24 @@ def search_embedding(
128133
) -> list[int]:
129134
assert self.conn is not None, "Connection is not initialized"
130135

131-
index_param = self.case_config.index_param()
136+
index_param = self.case_config.index_param() # noqa: F841
132137
search_param = self.case_config.search_param()
133138

134139
if filters:
135140
gt = filters.get("id")
136-
filterSql = (f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score '
137-
f'FROM {self.db_config["dbname"]}.{self.table_name} '
138-
f'WHERE id > {gt} '
139-
f'ORDER BY score LIMIT {k};'
140-
)
141-
result = self.conn.query(filterSql).result_rows
141+
filter_sql = (
142+
f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score ' # noqa: S608
143+
f'FROM {self.db_config["dbname"]}.{self.table_name} '
144+
f"WHERE id > {gt} "
145+
f"ORDER BY score LIMIT {k};"
146+
)
147+
result = self.conn.query(filter_sql).result_rows
142148
return [int(row[0]) for row in result]
143-
else:
144-
selectSql = (f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score '
145-
f'FROM {self.db_config["dbname"]}.{self.table_name} '
146-
f'ORDER BY score LIMIT {k};'
147-
)
148-
result = self.conn.query(selectSql).result_rows
149+
else: # noqa: RET505
150+
select_sql = (
151+
f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score ' # noqa: S608
152+
f'FROM {self.db_config["dbname"]}.{self.table_name} '
153+
f"ORDER BY score LIMIT {k};"
154+
)
155+
result = self.conn.query(select_sql).result_rows
149156
return [int(row[0]) for row in result]

vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def insert_embeddings(
8181
embeddings: Iterable[list[float]],
8282
metadata: list[int],
8383
**kwargs,
84-
) -> (int, Exception):
84+
) -> tuple[int, Exception]:
8585
"""Insert the embeddings to the elasticsearch."""
8686
assert self.client is not None, "should self.init() first"
8787

vectordb_bench/backend/clients/mariadb/cli.py

Lines changed: 58 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Annotated, Optional, Unpack
1+
from typing import Annotated, Unpack
22

33
import click
44
from pydantic import SecretStr
@@ -15,68 +15,84 @@
1515

1616
class MariaDBTypedDict(CommonTypedDict):
1717
user_name: Annotated[
18-
str, click.option("--username",
19-
type=str,
20-
help="Username",
21-
required=True,
22-
),
18+
str,
19+
click.option(
20+
"--username",
21+
type=str,
22+
help="Username",
23+
required=True,
24+
),
2325
]
2426
password: Annotated[
25-
str, click.option("--password",
26-
type=str,
27-
help="Password",
28-
required=True,
29-
),
27+
str,
28+
click.option(
29+
"--password",
30+
type=str,
31+
help="Password",
32+
required=True,
33+
),
3034
]
3135

3236
host: Annotated[
33-
str, click.option("--host",
34-
type=str,
35-
help="Db host",
36-
default="127.0.0.1",
37-
),
37+
str,
38+
click.option(
39+
"--host",
40+
type=str,
41+
help="Db host",
42+
default="127.0.0.1",
43+
),
3844
]
3945

4046
port: Annotated[
41-
int, click.option("--port",
42-
type=int,
43-
default=3306,
44-
help="Db Port",
45-
),
47+
int,
48+
click.option(
49+
"--port",
50+
type=int,
51+
default=3306,
52+
help="Db Port",
53+
),
4654
]
4755

4856
storage_engine: Annotated[
49-
int, click.option("--storage-engine",
50-
type=click.Choice(["InnoDB", "MyISAM"]),
51-
help="DB storage engine",
52-
required=True,
53-
),
57+
int,
58+
click.option(
59+
"--storage-engine",
60+
type=click.Choice(["InnoDB", "MyISAM"]),
61+
help="DB storage engine",
62+
required=True,
63+
),
5464
]
5565

66+
5667
class MariaDBHNSWTypedDict(MariaDBTypedDict):
57-
...
5868
m: Annotated[
59-
Optional[int], click.option("--m",
60-
type=int,
61-
help="M parameter in MHNSW vector indexing",
62-
required=False,
63-
),
69+
int | None,
70+
click.option(
71+
"--m",
72+
type=int,
73+
help="M parameter in MHNSW vector indexing",
74+
required=False,
75+
),
6476
]
6577

6678
ef_search: Annotated[
67-
Optional[int], click.option("--ef-search",
68-
type=int,
69-
help="MariaDB system variable mhnsw_min_limit",
70-
required=False,
71-
),
79+
int | None,
80+
click.option(
81+
"--ef-search",
82+
type=int,
83+
help="MariaDB system variable mhnsw_min_limit",
84+
required=False,
85+
),
7286
]
7387

7488
max_cache_size: Annotated[
75-
Optional[int], click.option("--max-cache-size",
76-
type=int,
77-
help="MariaDB system variable mhnsw_max_cache_size",
78-
required=False,
79-
),
89+
int | None,
90+
click.option(
91+
"--max-cache-size",
92+
type=int,
93+
help="MariaDB system variable mhnsw_max_cache_size",
94+
required=False,
95+
),
8096
]
8197

8298

vectordb_bench/backend/clients/mariadb/config.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
class MariaDBConfigDict(TypedDict):
99
"""These keys will be directly used as kwargs in mariadb connection string,
10-
so the names must match exactly mariadb API"""
10+
so the names must match exactly mariadb API"""
1111

1212
user: str
1313
password: str
@@ -44,6 +44,7 @@ def parse_metric(self) -> str:
4444
msg = f"Metric type {self.metric_type} is not supported!"
4545
raise ValueError(msg)
4646

47+
4748
class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig):
4849
M: int | None
4950
ef_search: int | None
@@ -68,7 +69,5 @@ def search_param(self) -> dict:
6869

6970

7071
_mariadb_case_config = {
71-
IndexType.HNSW: MariaDBHNSWConfig,
72+
IndexType.HNSW: MariaDBHNSWConfig,
7273
}
73-
74-

0 commit comments

Comments
 (0)