diff --git a/vectordb_bench/backend/dataset.py b/vectordb_bench/backend/dataset.py index 62700b0fa..f90580dc6 100644 --- a/vectordb_bench/backend/dataset.py +++ b/vectordb_bench/backend/dataset.py @@ -220,10 +220,12 @@ def prepare( train_files = utils.compose_train_files(file_count, use_shuffled) all_files = train_files - gt_file, test_file = None, None + test_file = "test.parquet" + all_files.extend([test_file]) + gt_file = None if self.data.with_gt: - gt_file, test_file = utils.compose_gt_file(filters), "test.parquet" - all_files.extend([gt_file, test_file]) + gt_file = utils.compose_gt_file(filters) + all_files.extend([gt_file]) if not self.data.is_custom: source.reader().read( @@ -232,8 +234,10 @@ def prepare( local_ds_root=self.data_dir, ) - if gt_file is not None and test_file is not None: + if test_file is not None: self.test_data = self._read_file(test_file) + + if gt_file is not None: self.gt_data = self._read_file(gt_file) prefix = "shuffle_train" if use_shuffled else "train" diff --git a/vectordb_bench/backend/runner/serial_runner.py b/vectordb_bench/backend/runner/serial_runner.py index 365641132..5b418c886 100644 --- a/vectordb_bench/backend/runner/serial_runner.py +++ b/vectordb_bench/backend/runner/serial_runner.py @@ -209,7 +209,8 @@ def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]: ideal_dcg = get_ideal_dcg(self.k) log.debug(f"test dataset size: {len(test_data)}") - log.debug(f"ground truth size: {ground_truth.columns}, shape: {ground_truth.shape}") + if ground_truth is not None: + log.debug(f"ground truth size: {ground_truth.columns}, shape: {ground_truth.shape}") latencies, recalls, ndcgs = [], [], [] for idx, emb in enumerate(test_data): @@ -228,9 +229,13 @@ def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]: latencies.append(time.perf_counter() - s) - gt = ground_truth["neighbors_id"][idx] - recalls.append(calc_recall(self.k, gt[: self.k], results)) - ndcgs.append(calc_ndcg(gt[: self.k], results, ideal_dcg)) + if ground_truth is not None: + gt = ground_truth["neighbors_id"][idx] + recalls.append(calc_recall(self.k, gt[: self.k], results)) + ndcgs.append(calc_ndcg(gt[: self.k], results, ideal_dcg)) + else: + recalls.append(0) + ndcgs.append(0) if len(latencies) % 100 == 0: log.debug(