Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions vectordb_bench/backend/clients/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ class MetricType(str, Enum):

class IndexType(str, Enum):
HNSW = "HNSW"
HNSW_SQ = "HNSW_SQ"
HNSW_PQ = "HNSW_PQ"
HNSW_PRQ = "HNSW_PRQ"
DISKANN = "DISKANN"
STREAMING_DISKANN = "DISKANN"
IVFFlat = "IVF_FLAT"
IVFSQ8 = "IVF_SQ8"
IVF_RABITQ = "IVF_RABITQ"
Flat = "FLAT"
AUTOINDEX = "AUTOINDEX"
ES_HNSW = "hnsw"
Expand All @@ -31,6 +35,14 @@ class IndexType(str, Enum):
SCANN = "scann"


class SQType(str, Enum):
SQ6 = "SQ6"
SQ8 = "SQ8"
BF16 = "BF16"
FP16 = "FP16"
FP32 = "FP32"


class DBConfig(ABC, BaseModel):
"""DBConfig contains the connection info of vector database

Expand Down
113 changes: 112 additions & 1 deletion vectordb_bench/backend/clients/milvus/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydantic import BaseModel, SecretStr, validator

from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType, SQType


class MilvusConfig(DBConfig):
Expand Down Expand Up @@ -88,6 +88,88 @@ def search_param(self) -> dict:
}


class HNSWSQConfig(HNSWConfig, DBCaseConfig):
index: IndexType = IndexType.HNSW_SQ
sq_type: SQType = SQType.SQ8
refine: bool = True
refine_type: SQType = SQType.FP32
refine_k: float = 1

def index_param(self) -> dict:
return {
"metric_type": self.parse_metric(),
"index_type": self.index.value,
"params": {
"M": self.M,
"efConstruction": self.efConstruction,
"sq_type": self.sq_type.value,
"refine": self.refine,
"refine_type": self.refine_type.value,
},
}

def search_param(self) -> dict:
return {
"metric_type": self.parse_metric(),
"params": {"ef": self.ef, "refine_k": self.refine_k},
}


class HNSWPQConfig(HNSWConfig):
index: IndexType = IndexType.HNSW_PQ
m: int = 32
nbits: int = 8
refine: bool = True
refine_type: SQType = SQType.FP32
refine_k: float = 1

def index_param(self) -> dict:
return {
"metric_type": self.parse_metric(),
"index_type": self.index.value,
"params": {
"M": self.M,
"efConstruction": self.efConstruction,
"m": self.m,
"nbits": self.nbits,
"refine": self.refine,
"refine_type": self.refine_type.value,
},
}

def search_param(self) -> dict:
return {
"metric_type": self.parse_metric(),
"params": {"ef": self.ef, "refine_k": self.refine_k},
}


class HNSWPRQConfig(HNSWPQConfig):
index: IndexType = IndexType.HNSW_PRQ
nrq: int = 2

def index_param(self) -> dict:
return {
"metric_type": self.parse_metric(),
"index_type": self.index.value,
"params": {
"M": self.M,
"efConstruction": self.efConstruction,
"m": self.m,
"nbits": self.nbits,
"nrq": self.nrq,
"refine": self.refine,
"refine_type": self.refine_type.value,
},
}

def search_param(self) -> dict:
return {
"metric_type": self.parse_metric(),
"params": {"ef": self.ef, "refine_k": self.refine_k},
}


class DISKANNConfig(MilvusIndexConfig, DBCaseConfig):
search_list: int | None = None
index: IndexType = IndexType.DISKANN
Expand Down Expand Up @@ -144,6 +226,31 @@ def search_param(self) -> dict:
}


class IVFRABITQConfig(IVFSQ8Config):
index: IndexType = IndexType.IVF_RABITQ
rbq_bits_query: int = 0 # 0, 1, 2, ..., 8
refine: bool = True
refine_type: SQType = SQType.FP32
refine_k: float = 1

def index_param(self) -> dict:
return {
"metric_type": self.parse_metric(),
"index_type": self.index.value,
"params": {
"nlist": self.nlist,
"refine": self.refine,
"refine_type": self.refine_type.value,
},
}

def search_param(self) -> dict:
return {
"metric_type": self.parse_metric(),
"params": {"nprobe": self.nprobe, "rbq_bits_query": self.rbq_bits_query, "refine_k": self.refine_k},
}


class FLATConfig(MilvusIndexConfig, DBCaseConfig):
index: IndexType = IndexType.Flat

Expand Down Expand Up @@ -285,9 +392,13 @@ def search_param(self) -> dict:
_milvus_case_config = {
IndexType.AUTOINDEX: AutoIndexConfig,
IndexType.HNSW: HNSWConfig,
IndexType.HNSW_SQ: HNSWSQConfig,
IndexType.HNSW_PQ: HNSWPQConfig,
IndexType.HNSW_PRQ: HNSWPRQConfig,
IndexType.DISKANN: DISKANNConfig,
IndexType.IVFFlat: IVFFlatConfig,
IndexType.IVFSQ8: IVFSQ8Config,
IndexType.IVF_RABITQ: IVFRABITQConfig,
IndexType.Flat: FLATConfig,
IndexType.GPU_IVF_FLAT: GPUIVFFlatConfig,
IndexType.GPU_IVF_PQ: GPUIVFPQConfig,
Expand Down
4 changes: 3 additions & 1 deletion vectordb_bench/backend/clients/milvus/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
consistency_level="Session",
)

log.info(f"{self.name} create index: index_params: {self.case_config.index_param()}")
col.create_index(
self._vector_field,
self.case_config.index_param(),
Expand All @@ -71,7 +72,7 @@ def __init__(
connections.disconnect("default")

@contextmanager
def init(self) -> None:
def init(self):
"""
Examples:
>>> with self.init():
Expand Down Expand Up @@ -126,6 +127,7 @@ def wait_index():
try:
self.col.compact()
self.col.wait_for_compaction_completed()
log.info("compactation completed. waiting for the rest of index buliding.")
except Exception as e:
log.warning(f"{self.name} compact error: {e}")
if hasattr(e, "code"):
Expand Down
111 changes: 104 additions & 7 deletions vectordb_bench/frontend/config/dbCaseConfigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel
from vectordb_bench.backend.cases import CaseLabel, CaseType
from vectordb_bench.backend.clients import DB
from vectordb_bench.backend.clients.api import IndexType, MetricType
from vectordb_bench.backend.clients.api import IndexType, MetricType, SQType
from vectordb_bench.frontend.components.custom.getCustomConfig import get_custom_configs

from vectordb_bench.models import CaseConfig, CaseConfigParamType
Expand Down Expand Up @@ -164,10 +164,13 @@ class CaseConfigInput(BaseModel):
inputConfig={
"options": [
IndexType.HNSW.value,
IndexType.HNSW_SQ.value,
IndexType.HNSW_PQ.value,
IndexType.HNSW_PRQ.value,
IndexType.IVFFlat.value,
IndexType.IVFSQ8.value,
IndexType.IVF_RABITQ.value,
IndexType.DISKANN.value,
IndexType.STREAMING_DISKANN.value,
IndexType.Flat.value,
IndexType.AUTOINDEX.value,
IndexType.GPU_IVF_FLAT.value,
Expand Down Expand Up @@ -346,9 +349,16 @@ class CaseConfigInput(BaseModel):
"max": 64,
"value": 30,
},
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value,
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
in [
IndexType.HNSW.value,
IndexType.HNSW_SQ.value,
IndexType.HNSW_PQ.value,
IndexType.HNSW_PRQ.value,
],
)


CaseConfigParamInput_m = CaseConfigInput(
label=CaseConfigParamType.m,
inputType=InputType.Number,
Expand All @@ -369,7 +379,62 @@ class CaseConfigInput(BaseModel):
"max": 512,
"value": 360,
},
isDisplayed=lambda config: config[CaseConfigParamType.IndexType] == IndexType.HNSW.value,
isDisplayed=lambda config: config[CaseConfigParamType.IndexType]
in [
IndexType.HNSW.value,
IndexType.HNSW_SQ.value,
IndexType.HNSW_PQ.value,
IndexType.HNSW_PRQ.value,
],
)

CaseConfigParamInput_SQType = CaseConfigInput(
label=CaseConfigParamType.sq_type,
inputType=InputType.Option,
inputHelp="Scalar quantizer type.",
inputConfig={
"options": [SQType.SQ6.value, SQType.SQ8.value, SQType.BF16.value, SQType.FP16.value, SQType.FP32.value]
},
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.HNSW_SQ.value],
)

CaseConfigParamInput_Refine = CaseConfigInput(
label=CaseConfigParamType.refine,
inputType=InputType.Option,
inputHelp="Whether refined data is reserved during index building.",
inputConfig={"options": [True, False]},
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
in [IndexType.HNSW_SQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVF_RABITQ.value],
)

CaseConfigParamInput_RefineType = CaseConfigInput(
label=CaseConfigParamType.refine_type,
inputType=InputType.Option,
inputHelp="The data type of the refine index.",
inputConfig={
"options": [SQType.FP32.value, SQType.FP16.value, SQType.BF16.value, SQType.SQ8.value, SQType.SQ6.value]
},
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
in [IndexType.HNSW_SQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVF_RABITQ.value]
and config.get(CaseConfigParamType.refine, True),
)

CaseConfigParamInput_RefineK = CaseConfigInput(
label=CaseConfigParamType.refine_k,
inputType=InputType.Float,
inputHelp="The magnification factor of refine compared to k.",
inputConfig={"min": 1.0, "max": 10000.0, "value": 1.0},
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
in [IndexType.HNSW_SQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVF_RABITQ.value]
and config.get(CaseConfigParamType.refine, True),
)

CaseConfigParamInput_RBQBitsQuery = CaseConfigInput(
label=CaseConfigParamType.rbq_bits_query,
inputType=InputType.Number,
inputHelp="The magnification factor of refine compared to k.",
inputConfig={"min": 0, "max": 8, "value": 0},
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.IVF_RABITQ.value],
)

CaseConfigParamInput_EFConstruction_Weaviate = CaseConfigInput(
Expand Down Expand Up @@ -519,7 +584,13 @@ class CaseConfigInput(BaseModel):
"max": MAX_STREAMLIT_INT,
"value": 100,
},
isDisplayed=lambda config: config[CaseConfigParamType.IndexType] == IndexType.HNSW.value,
isDisplayed=lambda config: config[CaseConfigParamType.IndexType]
in [
IndexType.HNSW.value,
IndexType.HNSW_SQ.value,
IndexType.HNSW_PQ.value,
IndexType.HNSW_PRQ.value,
],
)

CaseConfigParamInput_EF_Weaviate = CaseConfigInput(
Expand Down Expand Up @@ -561,6 +632,7 @@ class CaseConfigInput(BaseModel):
in [
IndexType.IVFFlat.value,
IndexType.IVFSQ8.value,
IndexType.IVF_RABITQ.value,
IndexType.GPU_IVF_FLAT.value,
IndexType.GPU_IVF_PQ.value,
IndexType.GPU_BRUTE_FORCE.value,
Expand All @@ -579,6 +651,7 @@ class CaseConfigInput(BaseModel):
in [
IndexType.IVFFlat.value,
IndexType.IVFSQ8.value,
IndexType.IVF_RABITQ.value,
IndexType.GPU_IVF_FLAT.value,
IndexType.GPU_IVF_PQ.value,
IndexType.GPU_BRUTE_FORCE.value,
Expand All @@ -593,7 +666,8 @@ class CaseConfigInput(BaseModel):
"max": 65536,
"value": 0,
},
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_IVF_PQ.value],
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
in [IndexType.GPU_IVF_PQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value],
)


Expand All @@ -605,7 +679,20 @@ class CaseConfigInput(BaseModel):
"max": 65536,
"value": 8,
},
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_IVF_PQ.value],
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
in [IndexType.GPU_IVF_PQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value],
)

CaseConfigParamInput_NRQ = CaseConfigInput(
label=CaseConfigParamType.nrq,
inputType=InputType.Number,
inputHelp="The number of residual subquantizers.",
inputConfig={
"min": 1,
"max": 16,
"value": 2,
},
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.HNSW_PRQ.value],
)

CaseConfigParamInput_intermediate_graph_degree = CaseConfigInput(
Expand Down Expand Up @@ -1186,6 +1273,10 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_graph_degree,
CaseConfigParamInput_build_algo,
CaseConfigParamInput_cache_dataset_on_device,
CaseConfigParamInput_SQType,
CaseConfigParamInput_Refine,
CaseConfigParamInput_RefineType,
CaseConfigParamInput_NRQ,
]
MilvusPerformanceConfig = [
CaseConfigParamInput_IndexType,
Expand All @@ -1197,6 +1288,8 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_Nprobe,
CaseConfigParamInput_M_PQ,
CaseConfigParamInput_Nbits_PQ,
CaseConfigParamInput_RBQBitsQuery,
CaseConfigParamInput_NRQ,
CaseConfigParamInput_intermediate_graph_degree,
CaseConfigParamInput_graph_degree,
CaseConfigParamInput_itopk_size,
Expand All @@ -1207,6 +1300,10 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_build_algo,
CaseConfigParamInput_cache_dataset_on_device,
CaseConfigParamInput_refine_ratio,
CaseConfigParamInput_SQType,
CaseConfigParamInput_Refine,
CaseConfigParamInput_RefineType,
CaseConfigParamInput_RefineK,
]

WeaviateLoadConfig = [
Expand Down
Loading