Skip to content

Commit 05d2b4c

Browse files
Fixed epsilon grid
1 parent 0b6c4aa commit 05d2b4c

2 files changed

Lines changed: 24 additions & 25 deletions

File tree

benchmarks/benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@
4848
########################################################################################################################
4949

5050
# Benchmark Config
51-
MAX_SAMPLES: int = 45000
51+
MAX_SAMPLES: int = 10000
5252
CONFIDENCE_INTERVALS_ITERATIONS: int = 1
5353
EMBEDDING_MODEL_1 = (
5454
"embedding_1",

vectorq/vectorq_core/vectorq_policy/strategies/bayesian.py

Lines changed: 23 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ class VectorQBayesianPolicy(VectorQPolicy):
2020
def __init__(self, delta: float, is_global: bool = False):
2121
self.delta: float = delta
2222
self.P_c: float = 1.0 - self.delta
23-
self.epsilon_grid: np.ndarray = np.linspace(0.01, 0.5, 50)
23+
self.epsilon_grid: np.ndarray = np.linspace(1e-6, 1-1e-6, 50)
2424
self.logistic_regression: LogisticRegression = LogisticRegression(
2525
penalty=None, solver="lbfgs", tol=1e-8, max_iter=1000, fit_intercept=False
2626
)
@@ -33,7 +33,7 @@ def __init__(self, delta: float, is_global: bool = False):
3333
self.global_t_hat: float = None
3434

3535
self.variance_map: Dict[int, List[float]] = {
36-
6: 0.002445,
36+
6: 0.012445,
3737
7: 0.014285,
3838
8: 0.014436,
3939
9: 0.011349,
@@ -125,18 +125,18 @@ def select_action(
125125
if len(similarities) < 6 or len(labels) < 6:
126126
return Action.EXPLORE
127127

128-
start_time = time.time()
128+
#start_time = time.time()
129129
t_hat, gamma, var_t = self._estimate_parameters(
130130
similarities=similarities, labels=labels
131131
)
132-
end_time_parameter_estimation = time.time()
133-
if self.is_global:
134-
sorted_observations = sorted(self.global_observations, key=lambda x: x[0])
135-
else:
136-
sorted_observations = sorted(metadata.observations, key=lambda x: x[0])
137-
logging.info(
138-
f"Embedding {metadata.embedding_id} | similarity: {similarity_score} | Observations: {sorted_observations}"
139-
)
132+
#end_time_parameter_estimation = time.time()
133+
#if self.is_global:
134+
# sorted_observations = sorted(self.global_observations, key=lambda x: x[0])
135+
#else:
136+
# sorted_observations = sorted(metadata.observations, key=lambda x: x[0])
137+
#logging.info(
138+
# f"Embedding {metadata.embedding_id} | similarity: {similarity_score} | Observations: {sorted_observations}"
139+
#)
140140
if t_hat == -1:
141141
return Action.EXPLORE
142142
if self.is_global:
@@ -146,14 +146,14 @@ def select_action(
146146
metadata.gamma = gamma
147147
metadata.t_hat = t_hat
148148

149-
start_time = time.time()
149+
#start_time = time.time()
150150
tau: float = self._get_tau(
151151
var_t=var_t, s=similarity_score, t_hat=t_hat, metadata=metadata
152152
)
153-
logging.info(f"t_hat: {t_hat}, gamma: {gamma}, tau: {tau}")
154-
logging.info(
155-
f"Parameter estimation: {(end_time_parameter_estimation - start_time):.4f} sec, Tau estimation: {(time.time() - end_time_parameter_estimation):.4f} sec\n"
156-
)
153+
#logging.info(f"t_hat: {t_hat}, gamma: {gamma}, tau: {tau}")
154+
#logging.info(
155+
# f"Parameter estimation: {(end_time_parameter_estimation - start_time):.4f} sec, Tau estimation: {(time.time() - end_time_parameter_estimation):.4f} sec\n"
156+
#)
157157

158158
u: float = random.uniform(0, 1)
159159
if u <= tau:
@@ -186,7 +186,6 @@ def _estimate_parameters(
186186

187187
t_hat = -intercept / (gamma + 1e-6)
188188
t_hat = float(np.clip(t_hat, 0.0, 1.0))
189-
# gamma = float(max(12.0, gamma))
190189

191190
similarities_col = (
192191
similarities[:, 1] if similarities.shape[1] > 1 else similarities[:, 0]
@@ -202,7 +201,7 @@ def _estimate_parameters(
202201
intercept=intercept,
203202
)
204203

205-
return round(t_hat, 3), round(gamma, 3), round(var_t, 5)
204+
return round(t_hat, 3), round(gamma, 3), var_t
206205

207206
except Exception as e:
208207
print(f"Logistic regression failed: {e}")
@@ -236,7 +235,7 @@ def _get_var_t(
236235
else:
237236
max_observations = max(self.variance_map.keys())
238237
var_t = self.variance_map[max_observations]
239-
logging.info(f"var_t (map): {round(var_t, 5)}")
238+
#logging.info(f"var_t (map): {round(var_t, 5)}")
240239
return var_t
241240
else:
242241
p = self.logistic_regression.predict_proba(X)[:, 1]
@@ -249,7 +248,7 @@ def _get_var_t(
249248

250249
var_t_hat = float(grad @ cov_beta @ grad)
251250
var_t_hat = max(0.0, var_t_hat)
252-
logging.info(f"var_t_hat (delta method): {round(var_t_hat, 5)}")
251+
#logging.info(f"var_t_hat (delta method): {round(var_t_hat, 5)}")
253252
return var_t_hat
254253

255254
def _get_tau(
@@ -275,9 +274,9 @@ def _get_tau(
275274
else:
276275
likelihoods = self._likelihood(s=s, t=t_primes, gamma=metadata.gamma)
277276
alpha_lower_bounds = (1 - self.epsilon_grid) * likelihoods
278-
logging.info(f"alpha_lower_bounds: {alpha_lower_bounds}")
277+
#logging.info(f"alpha_lower_bounds: {alpha_lower_bounds}")
279278
taus = 1 - (1 - self.P_c) / (1 - alpha_lower_bounds)
280-
return round(np.min(taus), 3)
279+
return round(np.min(taus), 5)
281280

282281
def _get_t_primes(self, t_hat: float, var_t: float) -> List[float]:
283282
"""
@@ -296,7 +295,7 @@ def _get_t_primes(self, t_hat: float, var_t: float) -> List[float]:
296295
for i in range(len(self.epsilon_grid))
297296
]
298297
)
299-
logging.info(f"t_primes: {t_primes}")
298+
#logging.info(f"t_primes: {t_primes}")
300299
return t_primes
301300

302301
def _confidence_interval(
@@ -314,7 +313,7 @@ def _confidence_interval(
314313
"""
315314
z = norm.ppf(quantile)
316315
t_prime = t_hat + z * np.sqrt(var_t)
317-
return round(float(np.clip(t_prime, 0.0, 1.0)), 3)
316+
return float(np.clip(t_prime, 0.0, 1.0))
318317

319318
def _likelihood(self, s: float, t: float, gamma: float) -> float:
320319
"""

0 commit comments

Comments (0)