Skip to content

Commit 24a54ec

Browse files
Cleaned code
1 parent 05d2b4c commit 24a54ec

3 files changed

Lines changed: 31 additions & 22 deletions

File tree

benchmarks/benchmark.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ def stats_set_up(self):
138138
self.latency_vectorq_list: List[float] = []
139139
self.observations_dict: Dict[str, Dict[str, float]] = {}
140140
self.gammas_dict: Dict[str, float] = {}
141+
self.t_hats_dict: Dict[str, float] = {}
142+
self.t_primes_dict: Dict[str, float] = {}
143+
self.var_ts_dict: Dict[str, float] = {}
141144

142145
if self.output_folder_path and not os.path.exists(self.output_folder_path):
143146
os.makedirs(self.output_folder_path)
@@ -309,6 +312,8 @@ def dump_results_to_json(self):
309312
observations_dict = {}
310313
gammas_dict = {}
311314
t_hats_dict = {}
315+
t_primes_dict = {}
316+
var_ts_dict = {}
312317

313318
metadata_objects: List[EmbeddingMetadataObj] = (
314319
self.vectorq.core.cache.get_all_embedding_metadata_objects()
@@ -320,21 +325,29 @@ def dump_results_to_json(self):
320325
)
321326
gammas_dict[metadata_object.embedding_id] = metadata_object.gamma
322327
t_hats_dict[metadata_object.embedding_id] = metadata_object.t_hat
328+
t_primes_dict[metadata_object.embedding_id] = metadata_object.t_prime
329+
var_ts_dict[metadata_object.embedding_id] = metadata_object.var_t
323330

324331
self.observations_dict = observations_dict
325332
self.gammas_dict = gammas_dict
326333
self.t_hats_dict = t_hats_dict
334+
self.t_primes_dict = t_primes_dict
335+
self.var_ts_dict = var_ts_dict
327336

328337
try:
329338
global_observations_dict = (
330339
self.vectorq.core.vectorq_policy.global_observations
331340
)
332341
global_gamma = self.vectorq.core.vectorq_policy.global_gamma
333342
global_t_hat = self.vectorq.core.vectorq_policy.global_t_hat
343+
global_t_prime = self.vectorq.core.vectorq_policy.global_t_prime
344+
global_var_t = self.vectorq.core.vectorq_policy.global_var_t
334345
except Exception:
335346
global_observations_dict = {}
336347
global_gamma = None
337348
global_t_hat = None
349+
global_t_prime = None
350+
global_var_t = None
338351

339352
data = {
340353
"config": {
@@ -355,9 +368,13 @@ def dump_results_to_json(self):
355368
"observations_dict": self.observations_dict,
356369
"gammas_dict": self.gammas_dict,
357370
"t_hats_dict": self.t_hats_dict,
371+
"t_primes_dict": self.t_primes_dict,
372+
"var_ts_dict": self.var_ts_dict,
358373
"global_observations_dict": global_observations_dict,
359374
"global_gamma": global_gamma,
360375
"global_t_hat": global_t_hat,
376+
"global_t_prime": global_t_prime,
377+
"global_var_t": global_var_t,
361378
}
362379

363380
filepath = self.output_folder_path + f"/results_{self.timestamp}.json"

vectorq/vectorq_core/cache/embedding_store/embedding_metadata_storage/embedding_metadata_obj.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def __init__(
2424
self.observations.append((1.0, 1))
2525
self.gamma: float = None
2626
self.t_hat: float = None
27+
self.t_prime: float = None
28+
self.var_t: float = None
2729
##################################################
2830

2931
# VectorQ Heuristic Policy #######################

vectorq/vectorq_core/vectorq_policy/strategies/bayesian.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import logging
21
import random
3-
import time
42
from typing import Dict, List, Tuple
53

64
import numpy as np
@@ -20,7 +18,7 @@ class VectorQBayesianPolicy(VectorQPolicy):
2018
def __init__(self, delta: float, is_global: bool = False):
2119
self.delta: float = delta
2220
self.P_c: float = 1.0 - self.delta
23-
self.epsilon_grid: np.ndarray = np.linspace(1e-6, 1-1e-6, 50)
21+
self.epsilon_grid: np.ndarray = np.linspace(1e-6, 1 - 1e-6, 50)
2422
self.logistic_regression: LogisticRegression = LogisticRegression(
2523
penalty=None, solver="lbfgs", tol=1e-8, max_iter=1000, fit_intercept=False
2624
)
@@ -31,6 +29,8 @@ def __init__(self, delta: float, is_global: bool = False):
3129
self.global_observations.append((1.0, 1))
3230
self.global_gamma: float = None
3331
self.global_t_hat: float = None
32+
self.global_t_prime: float = None
33+
self.global_var_t: float = None
3434

3535
self.variance_map: Dict[int, List[float]] = {
3636
6: 0.012445,
@@ -125,35 +125,24 @@ def select_action(
125125
if len(similarities) < 6 or len(labels) < 6:
126126
return Action.EXPLORE
127127

128-
#start_time = time.time()
129128
t_hat, gamma, var_t = self._estimate_parameters(
130129
similarities=similarities, labels=labels
131130
)
132-
#end_time_parameter_estimation = time.time()
133-
#if self.is_global:
134-
# sorted_observations = sorted(self.global_observations, key=lambda x: x[0])
135-
#else:
136-
# sorted_observations = sorted(metadata.observations, key=lambda x: x[0])
137-
#logging.info(
138-
# f"Embedding {metadata.embedding_id} | similarity: {similarity_score} | Observations: {sorted_observations}"
139-
#)
131+
140132
if t_hat == -1:
141133
return Action.EXPLORE
142134
if self.is_global:
143135
self.global_gamma = gamma
144136
self.global_t_hat = t_hat
137+
self.global_var_t = var_t
145138
else:
146139
metadata.gamma = gamma
147140
metadata.t_hat = t_hat
141+
metadata.var_t = var_t
148142

149-
#start_time = time.time()
150143
tau: float = self._get_tau(
151144
var_t=var_t, s=similarity_score, t_hat=t_hat, metadata=metadata
152145
)
153-
#logging.info(f"t_hat: {t_hat}, gamma: {gamma}, tau: {tau}")
154-
#logging.info(
155-
# f"Parameter estimation: {(end_time_parameter_estimation - start_time):.4f} sec, Tau estimation: {(time.time() - end_time_parameter_estimation):.4f} sec\n"
156-
#)
157146

158147
u: float = random.uniform(0, 1)
159148
if u <= tau:
@@ -235,7 +224,6 @@ def _get_var_t(
235224
else:
236225
max_observations = max(self.variance_map.keys())
237226
var_t = self.variance_map[max_observations]
238-
#logging.info(f"var_t (map): {round(var_t, 5)}")
239227
return var_t
240228
else:
241229
p = self.logistic_regression.predict_proba(X)[:, 1]
@@ -248,7 +236,6 @@ def _get_var_t(
248236

249237
var_t_hat = float(grad @ cov_beta @ grad)
250238
var_t_hat = max(0.0, var_t_hat)
251-
#logging.info(f"var_t_hat (delta method): {round(var_t_hat, 5)}")
252239
return var_t_hat
253240

254241
def _get_tau(
@@ -274,13 +261,17 @@ def _get_tau(
274261
else:
275262
likelihoods = self._likelihood(s=s, t=t_primes, gamma=metadata.gamma)
276263
alpha_lower_bounds = (1 - self.epsilon_grid) * likelihoods
277-
#logging.info(f"alpha_lower_bounds: {alpha_lower_bounds}")
264+
278265
taus = 1 - (1 - self.P_c) / (1 - alpha_lower_bounds)
266+
if self.is_global:
267+
self.global_t_prime = t_primes[np.argmin(taus)]
268+
else:
269+
metadata.t_prime = t_primes[np.argmin(taus)]
279270
return round(np.min(taus), 5)
280271

281272
def _get_t_primes(self, t_hat: float, var_t: float) -> List[float]:
282273
"""
283-
Compute all possible t_prime values
274+
Compute all possible t_prime values.
284275
Args
285276
t_hat: float - The estimated threshold
286277
var_t: float - The variance of t
@@ -295,7 +286,6 @@ def _get_t_primes(self, t_hat: float, var_t: float) -> List[float]:
295286
for i in range(len(self.epsilon_grid))
296287
]
297288
)
298-
#logging.info(f"t_primes: {t_primes}")
299289
return t_primes
300290

301291
def _confidence_interval(

0 commit comments

Comments
 (0)