forked from zilliztech/VectorDBBench
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprepare_dataset.py
More file actions
114 lines (95 loc) · 3.74 KB
/
prepare_dataset.py
File metadata and controls
114 lines (95 loc) · 3.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import wget
import argparse
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from datasets import load_dataset
import faiss
def get_args():
    """Build and parse the command-line arguments for dataset preparation.

    Returns:
        argparse.Namespace with ``dataset_name``, ``dataset_dir``,
        ``embedding_model`` and ``centroids_dir`` attributes.
    """
    # Only datasets with precomputed embeddings are supported, hence `choices`.
    supported_datasets = [
        "cryptolab-playground/pubmed-arxiv-abstract-embedding-gemma-300m",
        "cryptolab-playground/Bloomberg-Financial-News-embedding-gemma-300m",
    ]
    parser = argparse.ArgumentParser(
        description="Prepare dataset and ground truth neighbors for benchmarking."
    )
    parser.add_argument(
        "-d", "--dataset-name",
        type=str,
        choices=supported_datasets,
        default=supported_datasets[0],
        help="Huggingface dataset name to download.",
    )
    parser.add_argument(
        "--dataset-dir",
        type=str,
        default="./dataset/pubmed768d400k",
        help="Dataset directory to save the dataset and neighbors.",
    )
    parser.add_argument(
        "-e", "--embedding-model",
        type=str,
        default="embeddinggemma-300m",
        help="Embedding model name to download centroids for.",
    )
    parser.add_argument(
        "--centroids-dir",
        type=str,
        default="./centroids",
        help="Directory to save the centroids and tree info.",
    )
    return parser.parse_args()
def download_dataset(
    dataset_name: str,
    output_dir: str = "./dataset/pubmed768d400k"
) -> None:
    """Download a Huggingface dataset and save its splits as Parquet files.

    Writes ``train.parquet`` and ``test.parquet`` into ``output_dir``.

    Args:
        dataset_name: Huggingface dataset identifier to download.
        output_dir: Directory where the Parquet files are written.
    """
    # Ensure the target directory exists: previously this relied on the
    # caller creating it, so a direct call with a fresh path would fail.
    os.makedirs(output_dir, exist_ok=True)
    # load dataset (expects "train" and "test" splits)
    ds = load_dataset(dataset_name)
    train = ds["train"].to_pandas()
    test = ds["test"].to_pandas()
    # write each split to parquet
    train_table = pa.Table.from_pandas(train)
    pq.write_table(train_table, os.path.join(output_dir, "train.parquet"))
    test_table = pa.Table.from_pandas(test)
    pq.write_table(test_table, os.path.join(output_dir, "test.parquet"))
def prepare_neighbors(
    data_dir: str = "./dataset/pubmed768d400k",
    k: int | None = None,
) -> None:
    """Prepare ground truth neighbors via brute-force flat search, save as Parquet.

    Reads ``train.parquet``/``test.parquet`` from ``data_dir`` (each must
    contain an ``emb`` column of equal-length vectors), runs an exact
    inner-product search of every test vector against the train set, and
    writes the neighbor indices to ``neighbors.parquet``.

    Args:
        data_dir: Directory containing the train/test Parquet files.
        k: Number of nearest neighbors to retrieve per query. Defaults to
            the number of test queries, preserving the original behavior.
            NOTE(review): k = len(test) looks arbitrary for ground truth
            (benchmarks usually fix k, e.g. 100) — confirm intent.
    """
    # load dataset and stack the embedding column into dense float32 matrices
    train = pd.read_parquet(os.path.join(data_dir, "train.parquet"))
    test = pd.read_parquet(os.path.join(data_dir, "test.parquet"))
    train = np.stack(train["emb"].to_list()).astype("float32")
    test = np.stack(test["emb"].to_list()).astype("float32")
    dim = train.shape[1]
    # exact (non-approximate) inner-product search over the full train set
    index = faiss.IndexFlatIP(dim)
    index.add(train)
    if k is None:
        k = len(test)
    distances, indices = index.search(test, k)
    print(distances.shape, indices.shape)
    # save flat search result as neighbors; "id" is the test-query row index
    df = pd.DataFrame({
        "id": np.arange(len(indices)),
        "neighbors_id": indices.tolist()
    })
    table = pa.Table.from_pandas(df)
    pq.write_table(table, os.path.join(data_dir, "neighbors.parquet"))
def download_centroids(embedding_model: str, dataset_dir: str) -> None:
    """Fetch pre-computed centroids and tree info for the GAS VCT index.

    Downloads ``centroids.npy`` and ``tree_info.pkl`` from the Huggingface
    ``cryptolab-playground/gas-centroids`` dataset into
    ``<dataset_dir>/<embedding_model>/``.

    Raises:
        ValueError: if centroids are not published for ``embedding_model``.
    """
    # Only one model's centroids are published upstream so far.
    if embedding_model != "embeddinggemma-300m":
        raise ValueError(f"Centroids for {embedding_model} currently not available.")
    # https://huggingface.co/datasets/cryptolab-playground/gas-centroids
    dataset_link = f"https://huggingface.co/datasets/cryptolab-playground/gas-centroids/resolve/main/{embedding_model}"
    # fetch both artifacts into a per-model subdirectory
    target_dir = os.path.join(dataset_dir, embedding_model)
    os.makedirs(target_dir, exist_ok=True)
    for filename in ("centroids.npy", "tree_info.pkl"):
        wget.download(f"{dataset_link}/{filename}", out=os.path.join(target_dir, filename))
if __name__ == "__main__":
args = get_args()
os.makedirs(args.dataset_dir, exist_ok=True)
download_dataset(args.dataset_name, args.dataset_dir)
prepare_neighbors(args.dataset_dir)
download_centroids(args.embedding_model, args.centroids_dir)