1313import wget
1414from datasets import load_dataset
1515
16-
1716SUPPORTED_CASES = {
1817 "pubmed768d400k" : {
1918 "dataset_name" : "cryptolab-playground/pubmed-arxiv-abstract-embedding-gemma-300m" ,
20- "embedding_model" : "embeddinggemma-300m"
19+ "embedding_model" : "embeddinggemma-300m" ,
2120 },
2221 "bloomberg768d368k" : {
2322 "dataset_name" : "cryptolab-playground/Bloomberg-Financial-News-embedding-gemma-300m" ,
24- "embedding_model" : "embeddinggemma-300m"
23+ "embedding_model" : "embeddinggemma-300m" ,
2524 },
2625 "products512d400k" : {
2726 "dataset_name" : "cryptolab-playground/amazon-products-clip-vit-b-32" ,
28- "embedding_model" : "clip-vit-b-32"
27+ "embedding_model" : "clip-vit-b-32" ,
2928 },
30- "food512d101k" : {
31- "dataset_name" : "cryptolab-playground/food101-clip-vit-b-32" ,
32- "embedding_model" : "clip-vit-b-32"
33- }
29+ "food512d101k" : {"dataset_name" : "cryptolab-playground/food101-clip-vit-b-32" , "embedding_model" : "clip-vit-b-32" },
3430}
3531SUPPORTED_EMBEDDING_MODELS = ["embeddinggemma-300m" , "clip-vit-b-32" ]
3632
@@ -111,9 +107,7 @@ def download_centroids(embedding_model: str, dataset_dir: str) -> None:
111107 raise ValueError (f"Centroids for { embedding_model } currently not available." )
112108
113109 # BASE URL: https://huggingface.co/datasets/cryptolab-playground/gas-centroids
114- dataset_link = (
115- f"https://huggingface.co/datasets/cryptolab-playground/gas-centroids/resolve/main/{ embedding_model } "
116- )
110+ dataset_link = f"https://huggingface.co/datasets/cryptolab-playground/gas-centroids/resolve/main/{ embedding_model } "
117111
118112 # download
119113 os .makedirs (os .path .join (dataset_dir , embedding_model ), exist_ok = True )
@@ -124,10 +118,14 @@ def download_centroids(embedding_model: str, dataset_dir: str) -> None:
124118if __name__ == "__main__" :
125119 args = get_args ()
126120
127- base_dataset_dir = os .environ .get ("DATASET_LOCAL_DIR" , "/tmp/vectordb_bench/dataset" ) if args .dataset_dir is None else args .dataset_dir
121+ base_dataset_dir = (
122+ os .environ .get ("DATASET_LOCAL_DIR" , "/tmp/vectordb_bench/dataset" )
123+ if args .dataset_dir is None
124+ else args .dataset_dir
125+ )
128126 args .dataset_dir = os .path .join (base_dataset_dir , args .dataset_name )
129127 os .makedirs (args .dataset_dir , exist_ok = True )
130-
128+
131129 download_dataset (args .dataset_name , args .dataset_dir )
132130 prepare_neighbors (args .dataset_dir )
133131 download_centroids (SUPPORTED_CASES [args .dataset_name ]["embedding_model" ], args .centroids_dir )
0 commit comments