-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_rem.py
More file actions
186 lines (158 loc) · 6.03 KB
/
train_rem.py
File metadata and controls
186 lines (158 loc) · 6.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
# model input: <query, dist_start, dist_1st, dist_10th, 1st_to_start, 10th_to_start>
# model output: log2(predicted steps)
import matplotlib as mpl
mpl.use('Agg') # noqa
import matplotlib.pyplot as plt
import numpy as np
import argparse
import os
import pickle
from tqdm import tqdm
import time
import json
from scipy import stats
from benchmark.datasets import DATASETS
from benchmark.algorithms.definitions import get_definitions
from benchmark.plotting.metrics import all_metrics as metrics
from benchmark.plotting.metrics import get_all_recall_values, get_count_at_certain_recall
from benchmark.plotting.utils import (get_plot_label, compute_metrics,
create_linestyles, create_pointset)
from benchmark.results import (store_results, load_all_results, load_all_results_without_read,
get_result_filename, get_unique_algorithms)
from benchmark.dataset_io import knn_result_read
import benchmark.streaming.compute_gt
from benchmark.streaming.load_runbook import load_runbook
from benchmark.utils import read_gt_fromdir
from meta_analysis import read_float_arg_from_filename, read_int_arg_from_filename
if __name__ == "__main__":
    # Build a "recall -> efSearch" table (REM table) from previously stored
    # streaming-benchmark runs, and record the latency spent producing it.
    # Output: <model_base_path>/rem_table.txt (recall,ef per line, sorted by
    # recall) and <model_base_path>/latency.json.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset',
        metavar="DATASET",
        required=True)
    parser.add_argument(
        '--count',
        default=-1,
        type=int)
    parser.add_argument(
        '--definitions',
        metavar='FILE',
        help='load algorithm definitions from FILE',
        default='algos-2021.yaml')
    parser.add_argument(
        '--neurips23track',
        choices=['filter', 'ood', 'sparse', 'streaming', 'none'],
        default='none'
    )
    parser.add_argument(
        '--runbook_path',
        metavar='FILE',
        help='paths to runbooks',
    )
    parser.add_argument(
        '--results_base_path',
        type=str,
        default='results'
    )
    parser.add_argument(
        '--model_base_path',
        type=str,
        required=True
    )
    parser.add_argument(
        '--private-query',
        help='Use the private queries and ground truth',
        action='store_true')
    parser.add_argument(
        '--mode',
        type=str,
        required=True)
    parser.add_argument(
        '--filtered',
        help='Use filtered queries.',
        action='store_true'
    )
    parser.add_argument(
        '--label_file',
        type=str,
        default=None,
        help='Path to the label file.'
    )
    parser.add_argument(
        '--filter_label_file',
        type=str,
        default=None,
        help='Path to the filter file.'
    )
    parser.add_argument(
        '--num_queries',
        type=int,
        default=0,
        help='Number of queries to run.'
    )
    args = parser.parse_args()
    assert args.mode.startswith("train_rem"), "train_rem.py only supports train_rem mode"

    dataset = DATASETS[args.dataset]()
    dim = dataset.d  # NOTE(review): unused below; kept for parity with sibling scripts
    if args.count == -1:
        args.count = dataset.default_count()
    count = int(args.count)

    # Training queries are loaded only to determine nq (the recall denominator).
    Q = dataset.get_training_queries().astype(np.float32)
    if args.num_queries > 0:
        Q = Q[:args.num_queries]
    nq = Q.shape[0]
    print(f"Got {nq} queries")

    max_pts, runbook = load_runbook(args.dataset, dataset.nb, args.runbook_path)
    results = load_all_results(args.dataset, count, neurips23track=args.neurips23track,
                               runbook_path=args.runbook_path, filtered=args.filtered,
                               label_file=args.label_file, filter_label_file=args.filter_label_file,
                               base_path=args.results_base_path)

    rem_table = {}      # mean recall -> efSearch; NOTE: equal recalls overwrite (last run wins)
    step_latency = 0    # accumulated per-step search latency, converted us -> s below
    start_time = time.time()
    for fileroot, filename, properties, run in results:
        if not filename.startswith(args.mode):
            continue
        # Was corrupted to a literal "(unknown)" placeholder; restore the filename.
        print(f"from {fileroot}/{filename}")
        if fileroot.split("/")[-1].endswith('.txt'):
            # Filtered runs store <label_file>/<filter_label_file> as the last two
            # path components of the result root.
            label_file, filter_label_file = fileroot.split("/")[-2], fileroot.split("/")[-1]
        else:
            label_file, filter_label_file = "", ""
        gt_dir = benchmark.streaming.compute_gt.gt_dir(dataset, args.runbook_path, label_file, filter_label_file)

        # efSearch is encoded in the filename, so it is invariant across steps.
        ef = read_int_arg_from_filename(filename, "efSearch", 0)
        for search_idx in range(properties['num_searches']):
            search_step_id = properties['step_' + str(search_idx)]
            step_suffix = str(search_step_id)
            step_latency += properties['latency_step_' + step_suffix] / 10**6  # us -> s
            neighbors = np.array(run['neighbors_step' + step_suffix])
            groundtruths, groundtruth_distances = read_gt_fromdir(gt_dir, step_suffix, count, train=True)
            mean_recall = 0
            for (query_id, (
                    neighbors_per_query, groundtruths_per_query,
            )) in tqdm(enumerate(zip(
                    neighbors, groundtruths,
            )), total=nq, desc="Generating training data"):
                # Set membership is O(1) per ground-truth id vs an O(k) array scan.
                neighbor_set = set(neighbors_per_query.tolist())
                recall_per_query = sum(
                    1 for groundtruth in groundtruths_per_query
                    if groundtruth in neighbor_set
                )
                mean_recall += recall_per_query
            mean_recall = mean_recall / (count * nq)
            rem_table[mean_recall] = ef
    train_latency = time.time() - start_time

    # Dump ef_recall_map
    os.makedirs(args.model_base_path, exist_ok=True)
    rem_table_path = os.path.join(args.model_base_path, "rem_table.txt")
    with open(rem_table_path, "w") as f:
        # sort rem_table by recall
        rem_table = {k: v for k, v in sorted(rem_table.items(), key=lambda item: item[0])}
        for recall, ef in rem_table.items():
            f.write(f"{recall},{ef}\n")
    latency_path = os.path.join(args.model_base_path, "latency.json")
    with open(latency_path, "w") as f:
        json.dump({
            "run_latency": step_latency,
            "generate_latency": 0,
            "train_latency": train_latency,
            "mse": 0,
        }, f, indent=4)