Skip to content

Commit cb4712c

Browse files
Caroline-an777 authored and alwayslove2013 committed
custom control file and column name
custom control file and column name custom control file and column name custom control file and column name custom control name of file and column
1 parent cfa7e33 commit cb4712c

6 files changed

Lines changed: 79 additions & 7 deletions

File tree

vectordb_bench/backend/cases.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44

55
from vectordb_bench import config
66
from vectordb_bench.backend.clients.api import MetricType
7-
from vectordb_bench.backend.filter import Filter, FilterOp, IntFilter, LabelFilter, non_filter
7+
from vectordb_bench.backend.filter import Filter, FilterOp, IntFilter, LabelFilter, NonFilter, non_filter
88
from vectordb_bench.base import BaseModel
99
from vectordb_bench.frontend.components.custom.getCustomConfig import CustomDatasetConfig
1010

@@ -337,6 +337,7 @@ class PerformanceCustomDataset(PerformanceCase):
337337
case_id: CaseType = CaseType.PerformanceCustomDataset
338338
name: str = "Performance With Custom Dataset"
339339
description: str = ""
340+
gt_file: str
340341
dataset: DatasetManager
341342

342343
def __init__(
@@ -358,15 +359,26 @@ def __init__(
358359
with_gt=dataset_config.with_gt,
359360
dir=dataset_config.dir,
360361
file_num=dataset_config.file_count,
362+
train_file=dataset_config.train_name,
363+
test_file=f"{dataset_config.test_name}.parquet",
364+
train_id_field=dataset_config.train_id_name,
365+
train_vector_field=dataset_config.train_col_name,
366+
test_vector_field=dataset_config.test_col_name,
367+
gt_neighbors_field=dataset_config.gt_col_name,
361368
)
362369
super().__init__(
363370
name=name,
364371
description=description,
365372
load_timeout=load_timeout,
366373
optimize_timeout=optimize_timeout,
374+
gt_file=f"{dataset_config.gt_name}.parquet",
367375
dataset=DatasetManager(data=dataset),
368376
)
369377

378+
@property
379+
def filters(self) -> Filter:
380+
return NonFilter(gt_file_name=self.gt_file)
381+
370382

371383
class StreamingPerformanceCase(Case):
372384
case_id: CaseType = CaseType.StreamingPerformanceCase

vectordb_bench/backend/dataset.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,13 @@ class CustomDataset(BaseDataset):
8989
file_num: int
9090
is_custom: bool = True
9191
with_remote_resource: bool = False
92+
train_file: str = "train"
93+
train_id_field: str = "id"
94+
train_vector_field: str = "emb"
95+
test_file: str = "test.parquet"
96+
gt_file: str = "neighbors.parquet"
97+
test_vector_field: str = "emb"
98+
gt_neighbors_field: str = "neighbors_id"
9299

93100
@validator("size")
94101
def verify_size(cls, v: int):
@@ -106,6 +113,20 @@ def dir_name(self) -> str:
106113
def file_count(self) -> int:
107114
return self.file_num
108115

116+
@property
117+
def train_files(self) -> list[str]:
118+
train_file, train_count = self.train_file, self.file_count
119+
prefix = f"{train_file}"
120+
train_files = []
121+
if train_count > 1:
122+
prefix_s = [item.strip() for item in prefix.split(",") if item.strip()]
123+
for i in range(train_count):
124+
sub_file = f"{prefix_s[i]}.parquet"
125+
train_files.append(sub_file)
126+
else:
127+
train_files.append(f"{prefix}.parquet")
128+
return train_files
129+
109130

110131
class LAION(BaseDataset):
111132
name: str = "LAION"

vectordb_bench/backend/filter.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,10 +21,11 @@ def groundtruth_file(self) -> str:
2121
class NonFilter(Filter):
2222
type: FilterOp = FilterOp.NonFilter
2323
filter_rate: float = 0.0
24+
gt_file_name: str = "neighbors.parquet"
2425

2526
@property
2627
def groundtruth_file(self) -> str:
27-
return "neighbors.parquet"
28+
return self.gt_file_name
2829

2930

3031
non_filter = NonFilter()

vectordb_bench/backend/runner/serial_runner.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -56,9 +56,9 @@ def task(self) -> int:
5656
log.info(f"({mp.current_process().name:16}) Start inserting embeddings in batch {config.NUM_PER_BATCH}")
5757
start = time.perf_counter()
5858
for data_df in self.dataset:
59-
all_metadata = data_df["id"].tolist()
59+
all_metadata = data_df[self.dataset.data.train_id_field].tolist()
6060

61-
emb_np = np.stack(data_df["emb"])
61+
emb_np = np.stack(data_df[self.dataset.data.train_vector_field])
6262
if self.normalize:
6363
log.debug("normalize the 100k train data")
6464
all_embeddings = (emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis]).tolist()
@@ -175,8 +175,8 @@ def run_endlessness(self) -> int:
175175
# only 1 file
176176
data_df = next(iter(self.dataset))
177177
all_embeddings, all_metadata = (
178-
np.stack(data_df["emb"]).tolist(),
179-
data_df["id"].tolist(),
178+
np.stack(data_df[self.dataset.data.train_vector_field]).tolist(),
179+
data_df[self.dataset.data.train_id_field].tolist(),
180180
)
181181

182182
start_time = time.perf_counter()

vectordb_bench/frontend/components/custom/displayCustomCase.py

Lines changed: 32 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,38 @@ def displayCustomCase(customCase: CustomCaseConfig, st, key):
2323
"metric type", key=f"{key}_metric_type", options=["L2", "Cosine", "IP"]
2424
)
2525
customCase.dataset_config.file_count = columns[3].number_input(
26-
"train file count", key=f"{key}_file_count", value=customCase.dataset_config.file_count
26+
"train file count",
27+
key=f"{key}_file_count",
28+
value=customCase.dataset_config.file_count,
29+
help="if train file count is more than one, please input all your train file name and split with ','",
30+
)
31+
32+
columns = st.columns(3)
33+
customCase.dataset_config.train_name = columns[0].text_input(
34+
"train file name",
35+
key=f"{key}_train_name",
36+
value=customCase.dataset_config.train_name,
37+
help="if your file and column in the file is not named as previous explanation, please input the real name (for example: if the file name is `tr.parquet` and column name is `embbb`, then input tr and embbb)",
38+
)
39+
customCase.dataset_config.test_name = columns[1].text_input(
40+
"test file name", key=f"{key}_test_name", value=customCase.dataset_config.test_name
41+
)
42+
customCase.dataset_config.gt_name = columns[2].text_input(
43+
"ground truth file name", key=f"{key}_gt_name", value=customCase.dataset_config.gt_name
44+
)
45+
46+
columns = st.columns([1, 1, 2, 2])
47+
customCase.dataset_config.train_id_name = columns[0].text_input(
48+
"train id name", key=f"{key}_train_id_name", value=customCase.dataset_config.train_id_name
49+
)
50+
customCase.dataset_config.train_col_name = columns[1].text_input(
51+
"train emb name", key=f"{key}_train_col_name", value=customCase.dataset_config.train_col_name
52+
)
53+
customCase.dataset_config.test_col_name = columns[2].text_input(
54+
"test emb name", key=f"{key}_test_col_name", value=customCase.dataset_config.test_col_name
55+
)
56+
customCase.dataset_config.gt_col_name = columns[3].text_input(
57+
"ground truth emb name", key=f"{key}_gt_col_name", value=customCase.dataset_config.gt_col_name
2758
)
2859

2960
columns = st.columns(4)

vectordb_bench/frontend/components/custom/getCustomConfig.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,13 @@ class CustomDatasetConfig(BaseModel):
1414
file_count: int = 1
1515
use_shuffled: bool = False
1616
with_gt: bool = True
17+
train_name: str = "train"
18+
test_name: str = "test"
19+
gt_name: str = "neighbors"
20+
train_id_name: str = "id"
21+
train_col_name: str = "emb"
22+
test_col_name: str = "emb"
23+
gt_col_name: str = "neighbors_id"
1724

1825

1926
class CustomCaseConfig(BaseModel):

0 commit comments

Comments
 (0)