Skip to content

Commit b7bad93

Browse files
alwayslove2013 authored and XuanYang-cn committed
fix bugs when use custom_dataset without groundtruth file
Signed-off-by: min.tian <min.tian.cn@gmail.com>
1 parent 7f83936 commit b7bad93

2 files changed

Lines changed: 17 additions & 8 deletions

File tree

vectordb_bench/backend/dataset.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -220,10 +220,12 @@ def prepare(
220220
train_files = utils.compose_train_files(file_count, use_shuffled)
221221
all_files = train_files
222222

223-
gt_file, test_file = None, None
223+
test_file = "test.parquet"
224+
all_files.extend([test_file])
225+
gt_file = None
224226
if self.data.with_gt:
225-
gt_file, test_file = utils.compose_gt_file(filters), "test.parquet"
226-
all_files.extend([gt_file, test_file])
227+
gt_file = utils.compose_gt_file(filters)
228+
all_files.extend([gt_file])
227229

228230
if not self.data.is_custom:
229231
source.reader().read(
@@ -232,8 +234,10 @@ def prepare(
232234
local_ds_root=self.data_dir,
233235
)
234236

235-
if gt_file is not None and test_file is not None:
237+
if test_file is not None:
236238
self.test_data = self._read_file(test_file)
239+
240+
if gt_file is not None:
237241
self.gt_data = self._read_file(gt_file)
238242

239243
prefix = "shuffle_train" if use_shuffled else "train"

vectordb_bench/backend/runner/serial_runner.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -209,7 +209,8 @@ def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]:
209209
ideal_dcg = get_ideal_dcg(self.k)
210210

211211
log.debug(f"test dataset size: {len(test_data)}")
212-
log.debug(f"ground truth size: {ground_truth.columns}, shape: {ground_truth.shape}")
212+
if ground_truth is not None:
213+
log.debug(f"ground truth size: {ground_truth.columns}, shape: {ground_truth.shape}")
213214

214215
latencies, recalls, ndcgs = [], [], []
215216
for idx, emb in enumerate(test_data):
@@ -228,9 +229,13 @@ def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]:
228229

229230
latencies.append(time.perf_counter() - s)
230231

231-
gt = ground_truth["neighbors_id"][idx]
232-
recalls.append(calc_recall(self.k, gt[: self.k], results))
233-
ndcgs.append(calc_ndcg(gt[: self.k], results, ideal_dcg))
232+
if ground_truth is not None:
233+
gt = ground_truth["neighbors_id"][idx]
234+
recalls.append(calc_recall(self.k, gt[: self.k], results))
235+
ndcgs.append(calc_ndcg(gt[: self.k], results, ideal_dcg))
236+
else:
237+
recalls.append(0)
238+
ndcgs.append(0)
234239

235240
if len(latencies) % 100 == 0:
236241
log.debug(

0 commit comments

Comments
 (0)