@@ -20,7 +20,7 @@ class VectorQBayesianPolicy(VectorQPolicy):
2020 def __init__ (self , delta : float , is_global : bool = False ):
2121 self .delta : float = delta
2222 self .P_c : float = 1.0 - self .delta
23- self .epsilon_grid : np .ndarray = np .linspace (0.01 , 0.5 , 50 )
23+ self .epsilon_grid : np .ndarray = np .linspace (1e-6 , 1 - 1e-6 , 50 )
2424 self .logistic_regression : LogisticRegression = LogisticRegression (
2525 penalty = None , solver = "lbfgs" , tol = 1e-8 , max_iter = 1000 , fit_intercept = False
2626 )
@@ -33,7 +33,7 @@ def __init__(self, delta: float, is_global: bool = False):
3333 self .global_t_hat : float = None
3434
3535 self .variance_map : Dict [int , List [float ]] = {
36- 6 : 0.002445 ,
36+ 6 : 0.012445 ,
3737 7 : 0.014285 ,
3838 8 : 0.014436 ,
3939 9 : 0.011349 ,
@@ -125,18 +125,18 @@ def select_action(
125125 if len (similarities ) < 6 or len (labels ) < 6 :
126126 return Action .EXPLORE
127127
128- start_time = time .time ()
128+ # start_time = time.time()
129129 t_hat , gamma , var_t = self ._estimate_parameters (
130130 similarities = similarities , labels = labels
131131 )
132- end_time_parameter_estimation = time .time ()
133- if self .is_global :
134- sorted_observations = sorted (self .global_observations , key = lambda x : x [0 ])
135- else :
136- sorted_observations = sorted (metadata .observations , key = lambda x : x [0 ])
137- logging .info (
138- f"Embedding { metadata .embedding_id } | similarity: { similarity_score } | Observations: { sorted_observations } "
139- )
132+ # end_time_parameter_estimation = time.time()
133+ # if self.is_global:
134+ # sorted_observations = sorted(self.global_observations, key=lambda x: x[0])
135+ # else:
136+ # sorted_observations = sorted(metadata.observations, key=lambda x: x[0])
137+ # logging.info(
138+ # f"Embedding {metadata.embedding_id} | similarity: {similarity_score} | Observations: {sorted_observations}"
139+ # )
140140 if t_hat == - 1 :
141141 return Action .EXPLORE
142142 if self .is_global :
@@ -146,14 +146,14 @@ def select_action(
146146 metadata .gamma = gamma
147147 metadata .t_hat = t_hat
148148
149- start_time = time .time ()
149+ # start_time = time.time()
150150 tau : float = self ._get_tau (
151151 var_t = var_t , s = similarity_score , t_hat = t_hat , metadata = metadata
152152 )
153- logging .info (f"t_hat: { t_hat } , gamma: { gamma } , tau: { tau } " )
154- logging .info (
155- f"Parameter estimation: { (end_time_parameter_estimation - start_time ):.4f} sec, Tau estimation: { (time .time () - end_time_parameter_estimation ):.4f} sec\n "
156- )
153+ # logging.info(f"t_hat: {t_hat}, gamma: {gamma}, tau: {tau}")
154+ # logging.info(
155+ # f"Parameter estimation: {(end_time_parameter_estimation - start_time):.4f} sec, Tau estimation: {(time.time() - end_time_parameter_estimation):.4f} sec\n"
156+ # )
157157
158158 u : float = random .uniform (0 , 1 )
159159 if u <= tau :
@@ -186,7 +186,6 @@ def _estimate_parameters(
186186
187187 t_hat = - intercept / (gamma + 1e-6 )
188188 t_hat = float (np .clip (t_hat , 0.0 , 1.0 ))
189- # gamma = float(max(12.0, gamma))
190189
191190 similarities_col = (
192191 similarities [:, 1 ] if similarities .shape [1 ] > 1 else similarities [:, 0 ]
@@ -202,7 +201,7 @@ def _estimate_parameters(
202201 intercept = intercept ,
203202 )
204203
205- return round (t_hat , 3 ), round (gamma , 3 ), round ( var_t , 5 )
204+ return round (t_hat , 3 ), round (gamma , 3 ), var_t
206205
207206 except Exception as e :
208207 print (f"Logistic regression failed: { e } " )
@@ -236,7 +235,7 @@ def _get_var_t(
236235 else :
237236 max_observations = max (self .variance_map .keys ())
238237 var_t = self .variance_map [max_observations ]
239- logging .info (f"var_t (map): { round (var_t , 5 )} " )
238+ # logging.info(f"var_t (map): {round(var_t, 5)}")
240239 return var_t
241240 else :
242241 p = self .logistic_regression .predict_proba (X )[:, 1 ]
@@ -249,7 +248,7 @@ def _get_var_t(
249248
250249 var_t_hat = float (grad @ cov_beta @ grad )
251250 var_t_hat = max (0.0 , var_t_hat )
252- logging .info (f"var_t_hat (delta method): { round (var_t_hat , 5 )} " )
251+ # logging.info(f"var_t_hat (delta method): {round(var_t_hat, 5)}")
253252 return var_t_hat
254253
255254 def _get_tau (
@@ -275,9 +274,9 @@ def _get_tau(
275274 else :
276275 likelihoods = self ._likelihood (s = s , t = t_primes , gamma = metadata .gamma )
277276 alpha_lower_bounds = (1 - self .epsilon_grid ) * likelihoods
278- logging .info (f"alpha_lower_bounds: { alpha_lower_bounds } " )
277+ # logging.info(f"alpha_lower_bounds: {alpha_lower_bounds}")
279278 taus = 1 - (1 - self .P_c ) / (1 - alpha_lower_bounds )
280- return round (np .min (taus ), 3 )
279+ return round (np .min (taus ), 5 )
281280
282281 def _get_t_primes (self , t_hat : float , var_t : float ) -> List [float ]:
283282 """
@@ -296,7 +295,7 @@ def _get_t_primes(self, t_hat: float, var_t: float) -> List[float]:
296295 for i in range (len (self .epsilon_grid ))
297296 ]
298297 )
299- logging .info (f"t_primes: { t_primes } " )
298+ # logging.info(f"t_primes: {t_primes}")
300299 return t_primes
301300
302301 def _confidence_interval (
@@ -314,7 +313,7 @@ def _confidence_interval(
314313 """
315314 z = norm .ppf (quantile )
316315 t_prime = t_hat + z * np .sqrt (var_t )
317- return round ( float (np .clip (t_prime , 0.0 , 1.0 )), 3 )
316+ return float (np .clip (t_prime , 0.0 , 1.0 ))
318317
319318 def _likelihood (self , s : float , t : float , gamma : float ) -> float :
320319 """
0 commit comments