diff --git a/vectordb_bench/backend/clients/milvus/cli.py b/vectordb_bench/backend/clients/milvus/cli.py index 24a61566f..b9394cd06 100644 --- a/vectordb_bench/backend/clients/milvus/cli.py +++ b/vectordb_bench/backend/clients/milvus/cli.py @@ -29,6 +29,17 @@ class MilvusTypedDict(TypedDict): str | None, click.option("--password", type=str, help="Db password", required=False), ] + num_shards: Annotated[ + int, + click.option( + "--num-shards", + type=int, + help="Number of shards", + required=False, + default=1, + show_default=True, + ), + ] class MilvusAutoIndexTypedDict(CommonTypedDict, MilvusTypedDict): ... @@ -46,6 +57,7 @@ def MilvusAutoIndex(**parameters: Unpack[MilvusAutoIndexTypedDict]): uri=SecretStr(parameters["uri"]), user=parameters["user_name"], password=SecretStr(parameters["password"]), + num_shards=int(parameters["num_shards"]), ), db_case_config=AutoIndexConfig(), **parameters, @@ -64,6 +76,7 @@ def MilvusFlat(**parameters: Unpack[MilvusAutoIndexTypedDict]): uri=SecretStr(parameters["uri"]), user=parameters["user_name"], password=SecretStr(parameters["password"]), + num_shards=int(parameters["num_shards"]), ), db_case_config=FLATConfig(), **parameters, @@ -110,6 +123,7 @@ def MilvusIVFFlat(**parameters: Unpack[MilvusIVFFlatTypedDict]): uri=SecretStr(parameters["uri"]), user=parameters["user_name"], password=SecretStr(parameters["password"]), + num_shards=int(parameters["num_shards"]), ), db_case_config=IVFFlatConfig( nlist=parameters["nlist"], @@ -131,6 +145,7 @@ def MilvusIVFSQ8(**parameters: Unpack[MilvusIVFFlatTypedDict]): uri=SecretStr(parameters["uri"]), user=parameters["user_name"], password=SecretStr(parameters["password"]), + num_shards=int(parameters["num_shards"]), ), db_case_config=IVFSQ8Config( nlist=parameters["nlist"], @@ -156,6 +171,7 @@ def MilvusDISKANN(**parameters: Unpack[MilvusDISKANNTypedDict]): uri=SecretStr(parameters["uri"]), user=parameters["user_name"], password=SecretStr(parameters["password"]), + num_shards=int(parameters["num_shards"]), ), db_case_config=DISKANNConfig( search_list=parameters["search_list"], @@ -184,6 +200,7 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]): uri=SecretStr(parameters["uri"]), user=parameters["user_name"], password=SecretStr(parameters["password"]), + num_shards=int(parameters["num_shards"]), ), db_case_config=GPUIVFFlatConfig( nlist=parameters["nlist"], @@ -218,6 +235,7 @@ def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]): uri=SecretStr(parameters["uri"]), user=parameters["user_name"], password=SecretStr(parameters["password"]), + num_shards=int(parameters["num_shards"]), ), db_case_config=GPUBruteForceConfig( metric_type=parameters["metric_type"], @@ -249,6 +267,7 @@ def MilvusGPUIVFPQ(**parameters: Unpack[MilvusGPUIVFPQTypedDict]): uri=SecretStr(parameters["uri"]), user=parameters["user_name"], password=SecretStr(parameters["password"]), + num_shards=int(parameters["num_shards"]), ), db_case_config=GPUIVFPQConfig( nlist=parameters["nlist"], @@ -288,6 +307,7 @@ def MilvusGPUCAGRA(**parameters: Unpack[MilvusGPUCAGRATypedDict]): uri=SecretStr(parameters["uri"]), user=parameters["user_name"], password=SecretStr(parameters["password"]), + num_shards=int(parameters["num_shards"]), ), db_case_config=GPUCAGRAConfig( intermediate_graph_degree=parameters["intermediate_graph_degree"], diff --git a/vectordb_bench/backend/clients/milvus/config.py b/vectordb_bench/backend/clients/milvus/config.py index 672becf1b..bf5920817 100644 --- a/vectordb_bench/backend/clients/milvus/config.py +++ b/vectordb_bench/backend/clients/milvus/config.py @@ -7,6 +7,7 @@ class MilvusConfig(DBConfig): uri: SecretStr = "http://localhost:19530" user: str | None = None password: SecretStr | None = None + num_shards: int = 1 def to_dict(self) -> dict: return { diff --git a/vectordb_bench/backend/clients/milvus/milvus.py b/vectordb_bench/backend/clients/milvus/milvus.py index 465c51179..7ea7308c5 100644 --- a/vectordb_bench/backend/clients/milvus/milvus.py +++ b/vectordb_bench/backend/clients/milvus/milvus.py @@ -40,7 +40,12 @@ def __init__( from pymilvus import connections - connections.connect(**self.db_config, timeout=30) + connections.connect( + uri=self.db_config.get("uri"), + user=self.db_config.get("user"), + password=self.db_config.get("password"), + timeout=30, + ) if drop_old and utility.has_collection(self.collection_name): log.info(f"{self.name} client drop_old collection: {self.collection_name}") utility.drop_collection(self.collection_name) @@ -59,6 +64,7 @@ def __init__( name=self.collection_name, schema=CollectionSchema(fields), consistency_level="Session", + num_shards=self.db_config.get("num_shards"), ) log.info(f"{self.name} create index: index_params: {self.case_config.index_param()}")