Skip to content

Commit a8625c1

Browse files
alwayslove2013XuanYang-cn
authored andcommitted
upgrade ruff / black, reformat all
Signed-off-by: min.tian <min.tian.cn@gmail.com>
1 parent e0f35a4 commit a8625c1

6 files changed

Lines changed: 81 additions & 84 deletions

File tree

vectordb_bench/backend/clients/__init__.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ class DB(Enum):
4747
Clickhouse = "Clickhouse"
4848
Vespa = "Vespa"
4949
LanceDB = "LanceDB"
50-
5150

5251
@property
5352
def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915
@@ -76,10 +75,10 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915
7675
from .qdrant_cloud.qdrant_cloud import QdrantCloud
7776

7877
return QdrantCloud
79-
78+
8079
if self == DB.QdrantLocal:
8180
from .qdrant_local.qdrant_local import QdrantLocal
82-
81+
8382
return QdrantLocal
8483

8584
if self == DB.WeaviateCloud:
@@ -207,10 +206,12 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901, PLR0915
207206
from .qdrant_cloud.config import QdrantConfig
208207

209208
return QdrantConfig
210-
209+
211210
if self == DB.QdrantLocal:
212211
from .qdrant_local.config import QdrantLocalConfig
213212

213+
return QdrantLocalConfig
214+
214215
if self == DB.WeaviateCloud:
215216
from .weaviate_cloud.config import WeaviateConfig
216217

@@ -332,10 +333,10 @@ def case_config_cls( # noqa: C901, PLR0911, PLR0912
332333
from .qdrant_cloud.config import QdrantIndexConfig
333334

334335
return QdrantIndexConfig
335-
336+
336337
if self == DB.QdrantLocal:
337338
from .qdrant_local.config import QdrantLocalIndexConfig
338-
339+
339340
return QdrantLocalIndexConfig
340341

341342
if self == DB.WeaviateCloud:

vectordb_bench/backend/clients/qdrant_local/cli.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Annotated, TypedDict, Unpack
1+
from typing import Annotated, Unpack
22

33
import click
44
from pydantic import SecretStr
@@ -11,7 +11,6 @@
1111
run,
1212
)
1313

14-
1514
DBTYPE = DB.QdrantLocal
1615

1716

@@ -22,39 +21,35 @@ class QdrantLocalTypedDict(CommonTypedDict):
2221
]
2322
on_disk: Annotated[
2423
bool,
25-
click.option(
26-
"--on-disk", type=bool, default=False, help="Store the vectors and the HNSW index on disk"
27-
),
24+
click.option("--on-disk", type=bool, default=False, help="Store the vectors and the HNSW index on disk"),
2825
]
2926
m: Annotated[
3027
int,
31-
click.option(
32-
"--m", type=int, default=16, help="HNSW index parameter m, set 0 to disable the index"
33-
),
28+
click.option("--m", type=int, default=16, help="HNSW index parameter m, set 0 to disable the index"),
3429
]
3530
ef_construct: Annotated[
3631
int,
37-
click.option(
38-
"--ef-construct", type=int, default=200, help="HNSW index parameter ef_construct"
39-
),
32+
click.option("--ef-construct", type=int, default=200, help="HNSW index parameter ef_construct"),
4033
]
4134
hnsw_ef: Annotated[
4235
int,
4336
click.option(
44-
"--hnsw-ef", type=int, default=0, help="HNSW index parameter hnsw_ef, set 0 to use ef_construct for search",
37+
"--hnsw-ef",
38+
type=int,
39+
default=0,
40+
help="HNSW index parameter hnsw_ef, set 0 to use ef_construct for search",
4541
),
4642
]
4743

44+
4845
@cli.command()
4946
@click_parameter_decorators_from_typed_dict(QdrantLocalTypedDict)
5047
def QdrantLocal(**parameters: Unpack[QdrantLocalTypedDict]):
5148
from .config import QdrantLocalConfig, QdrantLocalIndexConfig
5249

5350
run(
5451
db=DBTYPE,
55-
db_config=QdrantLocalConfig(
56-
url=SecretStr(parameters["url"])
57-
),
52+
db_config=QdrantLocalConfig(url=SecretStr(parameters["url"])),
5853
db_case_config=QdrantLocalIndexConfig(
5954
on_disk=parameters["on_disk"],
6055
m=parameters["m"],
Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from pydantic import BaseModel, SecretStr
22

3-
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
3+
from ..api import DBCaseConfig, DBConfig, MetricType
4+
45

56
class QdrantLocalConfig(DBConfig):
67
url: SecretStr
7-
8+
89
def to_dict(self) -> dict:
910
return {
1011
"url": self.url.get_secret_value(),
@@ -17,7 +18,7 @@ class QdrantLocalIndexConfig(BaseModel, DBCaseConfig):
1718
ef_construct: int
1819
hnsw_ef: int | None = 0
1920
on_disk: bool | None = False
20-
21+
2122
def parse_metric(self) -> str:
2223
if self.metric_type == MetricType.L2:
2324
return "Euclid"
@@ -26,21 +27,21 @@ def parse_metric(self) -> str:
2627
return "Dot"
2728

2829
return "Cosine"
29-
30+
3031
def index_param(self) -> dict:
3132
return {
3233
"distance": self.parse_metric(),
3334
"m": self.m,
3435
"ef_construct": self.ef_construct,
3536
"on_disk": self.on_disk,
3637
}
37-
38+
3839
def search_param(self) -> dict:
3940
search_params = {
40-
"exact": False, # Force to use ANNs
41+
"exact": False, # Force to use ANNs
4142
}
42-
43+
4344
if self.hnsw_ef != 0:
4445
search_params["hnsw_ef"] = self.hnsw_ef
45-
46-
return search_params
46+
47+
return search_params

vectordb_bench/backend/clients/qdrant_local/qdrant_local.py

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,23 @@
2828
QDRANT_BATCH_SIZE = 100
2929

3030

31-
def qdrant_collection_exists(client, collection_name: str) -> bool:
31+
def qdrant_collection_exists(client: QdrantClient, collection_name: str) -> bool:
3232
collection_exists = True
33-
33+
3434
try:
3535
client.get_collection(collection_name)
36-
except Exception as e:
36+
except Exception:
3737
collection_exists = False
38-
38+
3939
return collection_exists
40-
40+
41+
4142
class QdrantLocal(VectorDB):
4243
def __init__(
4344
self,
4445
dim: int,
4546
db_config: dict,
46-
db_case_config: dict,
47+
db_case_config: QdrantLocalIndexConfig,
4748
collection_name: str = "QdrantLocalCollection",
4849
drop_old: bool = False,
4950
name: str = "QdrantLocal",
@@ -56,26 +57,26 @@ def __init__(
5657
self.search_parameter = self.case_config.search_param()
5758
self.collection_name = collection_name
5859
self.client = None
59-
60+
6061
self._primary_field = "pk"
6162
self._vector_field = "vector"
62-
63+
6364
client = QdrantClient(**self.db_config)
64-
65+
6566
# Lets just print the parameters here for double check
6667
log.info(f"Case config: {self.case_config.index_param()}")
6768
log.info(f"Search parameter: {self.search_parameter}")
68-
69+
6970
if drop_old and qdrant_collection_exists(client, self.collection_name):
7071
log.info(f"{self.name} client drop_old collection: {self.collection_name}")
7172
client.delete_collection(self.collection_name)
72-
73+
7374
if not qdrant_collection_exists(client, self.collection_name):
7475
log.info(f"{self.name} create collection: {self.collection_name}")
7576
self._create_collection(dim, client)
7677

7778
client = None
78-
79+
7980
@contextmanager
8081
def init(self):
8182
"""
@@ -89,11 +90,15 @@ def init(self):
8990
yield
9091
self.client = None
9192
del self.client
92-
93+
9394
def _create_collection(self, dim: int, qdrant_client: QdrantClient):
9495
log.info(f"Create collection: {self.collection_name}")
95-
log.info(f"Index parameters: m={self.case_config.index_param()['m']}, ef_construct={self.case_config.index_param()['ef_construct']}, on_disk={self.case_config.index_param()['on_disk']}")
96-
96+
log.info(
97+
f"Index parameters: m={self.case_config.index_param()['m']}, "
98+
f"ef_construct={self.case_config.index_param()['ef_construct']}, "
99+
f"on_disk={self.case_config.index_param()['on_disk']}"
100+
)
101+
97102
# If the on_disk is true, we enable both on disk index and vectors.
98103
try:
99104
qdrant_client.create_collection(
@@ -104,10 +109,10 @@ def _create_collection(self, dim: int, qdrant_client: QdrantClient):
104109
on_disk=self.case_config.index_param()["on_disk"],
105110
),
106111
hnsw_config=HnswConfigDiff(
107-
m = self.case_config.index_param()["m"],
112+
m=self.case_config.index_param()["m"],
108113
ef_construct=self.case_config.index_param()["ef_construct"],
109114
on_disk=self.case_config.index_param()["on_disk"],
110-
)
115+
),
111116
)
112117

113118
qdrant_client.create_payload_index(
@@ -121,7 +126,7 @@ def _create_collection(self, dim: int, qdrant_client: QdrantClient):
121126
return
122127
log.warning(f"Failed to create collection: {self.collection_name} error: {e}")
123128
raise e from None
124-
129+
125130
def optimize(self, data_size: int | None = None):
126131
assert self.client, "Please call self.init() before"
127132
# wait for vectors to be fully indexed
@@ -139,11 +144,11 @@ def optimize(self, data_size: int | None = None):
139144
)
140145
log.info(msg)
141146
return
142-
147+
143148
except Exception as e:
144149
log.warning(f"QdrantCloud ready to search error: {e}")
145150
raise e from None
146-
151+
147152
def insert_embeddings(
148153
self,
149154
embeddings: Iterable[list[float]],
@@ -163,7 +168,7 @@ def insert_embeddings(
163168
assert self.client is not None
164169
assert len(embeddings) == len(metadata)
165170
insert_count = 0
166-
171+
167172
# disable indexing for quick insertion
168173
self.client.update_collection(
169174
collection_name=self.collection_name,
@@ -185,13 +190,13 @@ def insert_embeddings(
185190
collection_name=self.collection_name,
186191
optimizer_config=OptimizersConfigDiff(indexing_threshold=100),
187192
)
188-
193+
189194
except Exception as e:
190195
log.info(f"Failed to insert data, {e}")
191196
return insert_count, e
192197
else:
193198
return insert_count, None
194-
199+
195200
def search_embedding(
196201
self,
197202
query: list[float],
@@ -203,7 +208,7 @@ def search_embedding(
203208
Should call self.init() first.
204209
"""
205210
assert self.client is not None
206-
211+
207212
f = None
208213
if filters:
209214
f = Filter(
@@ -215,17 +220,13 @@ def search_embedding(
215220
),
216221
),
217222
],
218-
)
219-
res = (
220-
self.client.query_points(
221-
collection_name=self.collection_name,
222-
query=query,
223-
limit=k,
224-
query_filter=f,
225-
search_params=SearchParams(**self.search_parameter),
226-
227-
).points
228-
)
229-
230-
return [result.id for result in res]
223+
)
224+
res = self.client.query_points(
225+
collection_name=self.collection_name,
226+
query=query,
227+
limit=k,
228+
query_filter=f,
229+
search_params=SearchParams(**self.search_parameter),
230+
).points
231231

232+
return [result.id for result in res]

vectordb_bench/backend/clients/weaviate_cloud/cli.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,33 +15,32 @@
1515
class WeaviateTypedDict(CommonTypedDict):
1616
api_key: Annotated[
1717
str,
18-
click.option("--api-key", type=str, help="Weaviate api key", required=False, default=''),
18+
click.option("--api-key", type=str, help="Weaviate api key", required=False, default=""),
1919
]
2020
url: Annotated[
2121
str,
2222
click.option("--url", type=str, help="Weaviate url", required=True),
2323
]
2424
no_auth: Annotated[
2525
bool,
26-
click.option("--no-auth", is_flag=True, help="Do not use api-key, set it to true if you are using a local setup. Default is False.", default=False),
26+
click.option(
27+
"--no-auth",
28+
is_flag=True,
29+
help="Do not use api-key, set it to true if you are using a local setup. Default is False.",
30+
default=False,
31+
),
2732
]
2833
m: Annotated[
2934
int,
30-
click.option(
31-
"--m", type=int, default=16, help="HNSW index parameter m."
32-
),
35+
click.option("--m", type=int, default=16, help="HNSW index parameter m."),
3336
]
3437
ef_construct: Annotated[
3538
int,
36-
click.option(
37-
"--ef-construction", type=int, default=256, help="HNSW index parameter ef_construction"
38-
),
39+
click.option("--ef-construction", type=int, default=256, help="HNSW index parameter ef_construction"),
3940
]
4041
ef: Annotated[
4142
int,
42-
click.option(
43-
"--ef", type=int, default=256, help="HNSW index parameter ef for search"
44-
),
43+
click.option("--ef", type=int, default=256, help="HNSW index parameter ef for search"),
4544
]
4645

4746

@@ -54,7 +53,7 @@ def Weaviate(**parameters: Unpack[WeaviateTypedDict]):
5453
db=DB.WeaviateCloud,
5554
db_config=WeaviateConfig(
5655
db_label=parameters["db_label"],
57-
api_key=SecretStr(parameters["api_key"]) if parameters["api_key"] != '' else SecretStr("-"),
56+
api_key=SecretStr(parameters["api_key"]) if parameters["api_key"] != "" else SecretStr("-"),
5857
url=SecretStr(parameters["url"]),
5958
no_auth=parameters["no_auth"],
6059
),

0 commit comments

Comments
 (0)