diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index 8070164dd..790da891b 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -22,6 +22,7 @@ class IndexType(str, Enum): DISKANN = "DISKANN" STREAMING_DISKANN = "DISKANN" IVFFlat = "IVF_FLAT" + IVFPQ = "IVF_PQ" IVFSQ8 = "IVF_SQ8" IVF_RABITQ = "IVF_RABITQ" Flat = "FLAT" diff --git a/vectordb_bench/backend/clients/milvus/config.py b/vectordb_bench/backend/clients/milvus/config.py index 07cd9aad8..672becf1b 100644 --- a/vectordb_bench/backend/clients/milvus/config.py +++ b/vectordb_bench/backend/clients/milvus/config.py @@ -207,6 +207,27 @@ def search_param(self) -> dict: } +class IVFPQConfig(MilvusIndexConfig, DBCaseConfig): + nlist: int + nprobe: int | None = None + m: int = 32 + nbits: int = 8 + index: IndexType = IndexType.IVFPQ + + def index_param(self) -> dict: + return { + "metric_type": self.parse_metric(), + "index_type": self.index.value, + "params": {"nlist": self.nlist, "m": self.m, "nbits": self.nbits}, + } + + def search_param(self) -> dict: + return { + "metric_type": self.parse_metric(), + "params": {"nprobe": self.nprobe}, + } + + class IVFSQ8Config(MilvusIndexConfig, DBCaseConfig): nlist: int nprobe: int | None = None @@ -397,6 +418,7 @@ def search_param(self) -> dict: IndexType.HNSW_PRQ: HNSWPRQConfig, IndexType.DISKANN: DISKANNConfig, IndexType.IVFFlat: IVFFlatConfig, + IndexType.IVFPQ: IVFPQConfig, IndexType.IVFSQ8: IVFSQ8Config, IndexType.IVF_RABITQ: IVFRABITQConfig, IndexType.Flat: FLATConfig, diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index bd59c3470..da5e91d91 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -168,6 +168,7 @@ class CaseConfigInput(BaseModel): IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVFFlat.value, + IndexType.IVFPQ.value, IndexType.IVFSQ8.value, IndexType.IVF_RABITQ.value, IndexType.DISKANN.value, @@ -631,6 +632,7 @@ class CaseConfigInput(BaseModel): isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [ IndexType.IVFFlat.value, + IndexType.IVFPQ.value, IndexType.IVFSQ8.value, IndexType.IVF_RABITQ.value, IndexType.GPU_IVF_FLAT.value, @@ -650,6 +652,7 @@ class CaseConfigInput(BaseModel): isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [ IndexType.IVFFlat.value, + IndexType.IVFPQ.value, IndexType.IVFSQ8.value, IndexType.IVF_RABITQ.value, IndexType.GPU_IVF_FLAT.value, @@ -662,12 +665,12 @@ class CaseConfigInput(BaseModel): label=CaseConfigParamType.m, inputType=InputType.Number, inputConfig={ - "min": 0, + "min": 1, "max": 65536, - "value": 0, + "value": 32, }, isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.GPU_IVF_PQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value], + in [IndexType.GPU_IVF_PQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVFPQ.value], ) @@ -680,7 +683,7 @@ class CaseConfigInput(BaseModel): "value": 8, }, isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.GPU_IVF_PQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value], + in [IndexType.GPU_IVF_PQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVFPQ.value], ) CaseConfigParamInput_NRQ = CaseConfigInput(