1- import logging
21import random
3- import time
42from typing import Dict , List , Tuple
53
64import numpy as np
@@ -20,7 +18,7 @@ class VectorQBayesianPolicy(VectorQPolicy):
2018 def __init__ (self , delta : float , is_global : bool = False ):
2119 self .delta : float = delta
2220 self .P_c : float = 1.0 - self .delta
23- self .epsilon_grid : np .ndarray = np .linspace (1e-6 , 1 - 1e-6 , 50 )
21+ self .epsilon_grid : np .ndarray = np .linspace (1e-6 , 1 - 1e-6 , 50 )
2422 self .logistic_regression : LogisticRegression = LogisticRegression (
2523 penalty = None , solver = "lbfgs" , tol = 1e-8 , max_iter = 1000 , fit_intercept = False
2624 )
@@ -31,6 +29,8 @@ def __init__(self, delta: float, is_global: bool = False):
3129 self .global_observations .append ((1.0 , 1 ))
3230 self .global_gamma : float = None
3331 self .global_t_hat : float = None
32+ self .global_t_prime : float = None
33+ self .global_var_t : float = None
3434
3535 self .variance_map : Dict [int , List [float ]] = {
3636 6 : 0.012445 ,
@@ -125,35 +125,24 @@ def select_action(
125125 if len (similarities ) < 6 or len (labels ) < 6 :
126126 return Action .EXPLORE
127127
128- #start_time = time.time()
129128 t_hat , gamma , var_t = self ._estimate_parameters (
130129 similarities = similarities , labels = labels
131130 )
132- #end_time_parameter_estimation = time.time()
133- #if self.is_global:
134- # sorted_observations = sorted(self.global_observations, key=lambda x: x[0])
135- #else:
136- # sorted_observations = sorted(metadata.observations, key=lambda x: x[0])
137- #logging.info(
138- # f"Embedding {metadata.embedding_id} | similarity: {similarity_score} | Observations: {sorted_observations}"
139- #)
131+
140132 if t_hat == - 1 :
141133 return Action .EXPLORE
142134 if self .is_global :
143135 self .global_gamma = gamma
144136 self .global_t_hat = t_hat
137+ self .global_var_t = var_t
145138 else :
146139 metadata .gamma = gamma
147140 metadata .t_hat = t_hat
141+ metadata .var_t = var_t
148142
149- #start_time = time.time()
150143 tau : float = self ._get_tau (
151144 var_t = var_t , s = similarity_score , t_hat = t_hat , metadata = metadata
152145 )
153- #logging.info(f"t_hat: {t_hat}, gamma: {gamma}, tau: {tau}")
154- #logging.info(
155- # f"Parameter estimation: {(end_time_parameter_estimation - start_time):.4f} sec, Tau estimation: {(time.time() - end_time_parameter_estimation):.4f} sec\n"
156- #)
157146
158147 u : float = random .uniform (0 , 1 )
159148 if u <= tau :
@@ -235,7 +224,6 @@ def _get_var_t(
235224 else :
236225 max_observations = max (self .variance_map .keys ())
237226 var_t = self .variance_map [max_observations ]
238- #logging.info(f"var_t (map): {round(var_t, 5)}")
239227 return var_t
240228 else :
241229 p = self .logistic_regression .predict_proba (X )[:, 1 ]
@@ -248,7 +236,6 @@ def _get_var_t(
248236
249237 var_t_hat = float (grad @ cov_beta @ grad )
250238 var_t_hat = max (0.0 , var_t_hat )
251- #logging.info(f"var_t_hat (delta method): {round(var_t_hat, 5)}")
252239 return var_t_hat
253240
254241 def _get_tau (
@@ -274,13 +261,17 @@ def _get_tau(
274261 else :
275262 likelihoods = self ._likelihood (s = s , t = t_primes , gamma = metadata .gamma )
276263 alpha_lower_bounds = (1 - self .epsilon_grid ) * likelihoods
277- #logging.info(f"alpha_lower_bounds: {alpha_lower_bounds}")
264+
278265 taus = 1 - (1 - self .P_c ) / (1 - alpha_lower_bounds )
266+ if self .is_global :
267+ self .global_t_prime = t_primes [np .argmin (taus )]
268+ else :
269+ metadata .t_prime = t_primes [np .argmin (taus )]
279270 return round (np .min (taus ), 5 )
280271
281272 def _get_t_primes (self , t_hat : float , var_t : float ) -> List [float ]:
282273 """
283- Compute all possible t_prime values
274+ Compute all possible t_prime values.
284275 Args
285276 t_hat: float - The estimated threshold
286277 var_t: float - The variance of t
@@ -295,7 +286,6 @@ def _get_t_primes(self, t_hat: float, var_t: float) -> List[float]:
295286 for i in range (len (self .epsilon_grid ))
296287 ]
297288 )
298- #logging.info(f"t_primes: {t_primes}")
299289 return t_primes
300290
301291 def _confidence_interval (
0 commit comments