Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class DB(Enum):
Clickhouse = "Clickhouse"
Vespa = "Vespa"
LanceDB = "LanceDB"


@property
def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915
Expand Down Expand Up @@ -76,10 +75,10 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915
from .qdrant_cloud.qdrant_cloud import QdrantCloud

return QdrantCloud

if self == DB.QdrantLocal:
from .qdrant_local.qdrant_local import QdrantLocal

return QdrantLocal

if self == DB.WeaviateCloud:
Expand Down Expand Up @@ -207,10 +206,12 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901, PLR0915
from .qdrant_cloud.config import QdrantConfig

return QdrantConfig

if self == DB.QdrantLocal:
from .qdrant_local.config import QdrantLocalConfig

return QdrantLocalConfig

if self == DB.WeaviateCloud:
from .weaviate_cloud.config import WeaviateConfig

Expand Down Expand Up @@ -332,10 +333,10 @@ def case_config_cls( # noqa: C901, PLR0911, PLR0912
from .qdrant_cloud.config import QdrantIndexConfig

return QdrantIndexConfig

if self == DB.QdrantLocal:
from .qdrant_local.config import QdrantLocalIndexConfig

return QdrantLocalIndexConfig

if self == DB.WeaviateCloud:
Expand Down
25 changes: 10 additions & 15 deletions vectordb_bench/backend/clients/qdrant_local/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated, TypedDict, Unpack
from typing import Annotated, Unpack

import click
from pydantic import SecretStr
Expand All @@ -11,7 +11,6 @@
run,
)


DBTYPE = DB.QdrantLocal


Expand All @@ -22,39 +21,35 @@ class QdrantLocalTypedDict(CommonTypedDict):
]
on_disk: Annotated[
bool,
click.option(
"--on-disk", type=bool, default=False, help="Store the vectors and the HNSW index on disk"
),
click.option("--on-disk", type=bool, default=False, help="Store the vectors and the HNSW index on disk"),
]
m: Annotated[
int,
click.option(
"--m", type=int, default=16, help="HNSW index parameter m, set 0 to disable the index"
),
click.option("--m", type=int, default=16, help="HNSW index parameter m, set 0 to disable the index"),
]
ef_construct: Annotated[
int,
click.option(
"--ef-construct", type=int, default=200, help="HNSW index parameter ef_construct"
),
click.option("--ef-construct", type=int, default=200, help="HNSW index parameter ef_construct"),
]
hnsw_ef: Annotated[
int,
click.option(
"--hnsw-ef", type=int, default=0, help="HNSW index parameter hnsw_ef, set 0 to use ef_construct for search",
"--hnsw-ef",
type=int,
default=0,
help="HNSW index parameter hnsw_ef, set 0 to use ef_construct for search",
),
]


@cli.command()
@click_parameter_decorators_from_typed_dict(QdrantLocalTypedDict)
def QdrantLocal(**parameters: Unpack[QdrantLocalTypedDict]):
from .config import QdrantLocalConfig, QdrantLocalIndexConfig

run(
db=DBTYPE,
db_config=QdrantLocalConfig(
url=SecretStr(parameters["url"])
),
db_config=QdrantLocalConfig(url=SecretStr(parameters["url"])),
db_case_config=QdrantLocalIndexConfig(
on_disk=parameters["on_disk"],
m=parameters["m"],
Expand Down
19 changes: 10 additions & 9 deletions vectordb_bench/backend/clients/qdrant_local/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from pydantic import BaseModel, SecretStr

from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
from ..api import DBCaseConfig, DBConfig, MetricType


class QdrantLocalConfig(DBConfig):
url: SecretStr

def to_dict(self) -> dict:
return {
"url": self.url.get_secret_value(),
Expand All @@ -17,7 +18,7 @@ class QdrantLocalIndexConfig(BaseModel, DBCaseConfig):
ef_construct: int
hnsw_ef: int | None = 0
on_disk: bool | None = False

def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "Euclid"
Expand All @@ -26,21 +27,21 @@ def parse_metric(self) -> str:
return "Dot"

return "Cosine"

def index_param(self) -> dict:
    """Return the index/collection settings for this case config.

    Returns:
        A dict with the keys ``"distance"`` (the value produced by
        ``parse_metric()``), ``"m"``, ``"ef_construct"`` and ``"on_disk"``,
        taken directly from the corresponding fields of this config.
    """
    return {
        "distance": self.parse_metric(),
        "m": self.m,
        "ef_construct": self.ef_construct,
        "on_disk": self.on_disk,
    }

def search_param(self) -> dict:
    """Return the query-time search parameters for this case config.

    Returns:
        A dict that always contains ``"exact": False`` and additionally
        contains ``"hnsw_ef"`` when ``self.hnsw_ef`` is set to a non-zero
        value.
    """
    search_params = {
        "exact": False,  # Force to use ANNs
    }

    # ``hnsw_ef`` is declared as ``int | None``: treat both 0 and None as
    # "not set" so that no explicit ``hnsw_ef`` override is emitted.
    if self.hnsw_ef:
        search_params["hnsw_ef"] = self.hnsw_ef

    return search_params
75 changes: 38 additions & 37 deletions vectordb_bench/backend/clients/qdrant_local/qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,23 @@
QDRANT_BATCH_SIZE = 100


def qdrant_collection_exists(client: "QdrantClient", collection_name: str) -> bool:
    """Report whether ``collection_name`` can be fetched through ``client``.

    Args:
        client: Client object exposing ``get_collection(collection_name)``.
        collection_name: Name of the collection to look up.

    Returns:
        ``True`` if ``client.get_collection`` returns without raising,
        ``False`` if it raises any ``Exception``.
    """
    try:
        client.get_collection(collection_name)
    except Exception:
        # Deliberately broad: any failure of the lookup (not only a missing
        # collection) is reported as "does not exist".
        return False
    return True



class QdrantLocal(VectorDB):
def __init__(
self,
dim: int,
db_config: dict,
db_case_config: dict,
db_case_config: QdrantLocalIndexConfig,
collection_name: str = "QdrantLocalCollection",
drop_old: bool = False,
name: str = "QdrantLocal",
Expand All @@ -56,26 +57,26 @@ def __init__(
self.search_parameter = self.case_config.search_param()
self.collection_name = collection_name
self.client = None

self._primary_field = "pk"
self._vector_field = "vector"

client = QdrantClient(**self.db_config)

# Lets just print the parameters here for double check
log.info(f"Case config: {self.case_config.index_param()}")
log.info(f"Search parameter: {self.search_parameter}")

if drop_old and qdrant_collection_exists(client, self.collection_name):
log.info(f"{self.name} client drop_old collection: {self.collection_name}")
client.delete_collection(self.collection_name)

if not qdrant_collection_exists(client, self.collection_name):
log.info(f"{self.name} create collection: {self.collection_name}")
self._create_collection(dim, client)

client = None

@contextmanager
def init(self):
"""
Expand All @@ -89,11 +90,15 @@ def init(self):
yield
self.client = None
del self.client

def _create_collection(self, dim: int, qdrant_client: QdrantClient):
log.info(f"Create collection: {self.collection_name}")
log.info(f"Index parameters: m={self.case_config.index_param()['m']}, ef_construct={self.case_config.index_param()['ef_construct']}, on_disk={self.case_config.index_param()['on_disk']}")

log.info(
f"Index parameters: m={self.case_config.index_param()['m']}, "
f"ef_construct={self.case_config.index_param()['ef_construct']}, "
f"on_disk={self.case_config.index_param()['on_disk']}"
)

# If the on_disk is true, we enable both on disk index and vectors.
try:
qdrant_client.create_collection(
Expand All @@ -104,10 +109,10 @@ def _create_collection(self, dim: int, qdrant_client: QdrantClient):
on_disk=self.case_config.index_param()["on_disk"],
),
hnsw_config=HnswConfigDiff(
m = self.case_config.index_param()["m"],
m=self.case_config.index_param()["m"],
ef_construct=self.case_config.index_param()["ef_construct"],
on_disk=self.case_config.index_param()["on_disk"],
)
),
)

qdrant_client.create_payload_index(
Expand All @@ -121,7 +126,7 @@ def _create_collection(self, dim: int, qdrant_client: QdrantClient):
return
log.warning(f"Failed to create collection: {self.collection_name} error: {e}")
raise e from None

def optimize(self, data_size: int | None = None):
assert self.client, "Please call self.init() before"
# wait for vectors to be fully indexed
Expand All @@ -139,11 +144,11 @@ def optimize(self, data_size: int | None = None):
)
log.info(msg)
return

except Exception as e:
log.warning(f"QdrantCloud ready to search error: {e}")
raise e from None

def insert_embeddings(
self,
embeddings: Iterable[list[float]],
Expand All @@ -163,7 +168,7 @@ def insert_embeddings(
assert self.client is not None
assert len(embeddings) == len(metadata)
insert_count = 0

# disable indexing for quick insertion
self.client.update_collection(
collection_name=self.collection_name,
Expand All @@ -185,13 +190,13 @@ def insert_embeddings(
collection_name=self.collection_name,
optimizer_config=OptimizersConfigDiff(indexing_threshold=100),
)

except Exception as e:
log.info(f"Failed to insert data, {e}")
return insert_count, e
else:
return insert_count, None

def search_embedding(
self,
query: list[float],
Expand All @@ -203,7 +208,7 @@ def search_embedding(
Should call self.init() first.
"""
assert self.client is not None

f = None
if filters:
f = Filter(
Expand All @@ -215,17 +220,13 @@ def search_embedding(
),
),
],
)
res = (
self.client.query_points(
collection_name=self.collection_name,
query=query,
limit=k,
query_filter=f,
search_params=SearchParams(**self.search_parameter),

).points
)

return [result.id for result in res]
)
res = self.client.query_points(
collection_name=self.collection_name,
query=query,
limit=k,
query_filter=f,
search_params=SearchParams(**self.search_parameter),
).points

return [result.id for result in res]
23 changes: 11 additions & 12 deletions vectordb_bench/backend/clients/weaviate_cloud/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,32 @@
class WeaviateTypedDict(CommonTypedDict):
    """Command-line options for the Weaviate command.

    Each field is annotated with the ``click.option`` that supplies its value.
    """

    api_key: Annotated[
        str,
        click.option("--api-key", type=str, help="Weaviate api key", required=False, default=""),
    ]
    url: Annotated[
        str,
        click.option("--url", type=str, help="Weaviate url", required=True),
    ]
    no_auth: Annotated[
        bool,
        click.option(
            "--no-auth",
            is_flag=True,
            help="Do not use api-key, set it to true if you are using a local setup. Default is False.",
            default=False,
        ),
    ]
    m: Annotated[
        int,
        click.option("--m", type=int, default=16, help="HNSW index parameter m."),
    ]
    ef_construct: Annotated[
        int,
        click.option("--ef-construction", type=int, default=256, help="HNSW index parameter ef_construction"),
    ]
    ef: Annotated[
        int,
        click.option("--ef", type=int, default=256, help="HNSW index parameter ef for search"),
    ]


Expand All @@ -54,7 +53,7 @@ def Weaviate(**parameters: Unpack[WeaviateTypedDict]):
db=DB.WeaviateCloud,
db_config=WeaviateConfig(
db_label=parameters["db_label"],
api_key=SecretStr(parameters["api_key"]) if parameters["api_key"] != '' else SecretStr("-"),
api_key=SecretStr(parameters["api_key"]) if parameters["api_key"] != "" else SecretStr("-"),
url=SecretStr(parameters["url"]),
no_auth=parameters["no_auth"],
),
Expand Down
Loading