Skip to content

Commit f5cc406

Browse files
Implemented initial Bayesian Logic
1 parent 2c99678 commit f5cc406

25 files changed

Lines changed: 621 additions & 1091 deletions

File tree

benchmarks/_plotter_combined.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,7 +1160,7 @@ def plot_cache_hit_latency_vs_size_comparison(
11601160
linewidth=3,
11611161
label=f"Static Trend: {z_static[0]:.6f}x + {z_static[1]:.6f}",
11621162
)
1163-
except: # noqa: E722
1163+
except Exception:
11641164
pass
11651165

11661166
# Fit dynamic trend line if enough points
@@ -1176,7 +1176,7 @@ def plot_cache_hit_latency_vs_size_comparison(
11761176
linewidth=3,
11771177
label=f"Dynamic Trend: {z_dynamic[0]:.6f}x + {z_dynamic[1]:.6f}",
11781178
)
1179-
except: # noqa: E722
1179+
except Exception: # Replace bare except
11801180
pass
11811181

11821182
plt.xlabel("Cache Size (MB)")

benchmarks/_plotter_individual.py

Lines changed: 131 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
11
import os
2+
from typing import TYPE_CHECKING
23

34
import matplotlib.pyplot as plt
45
import numpy as np
56
import seaborn as sns
67
from matplotlib.patches import Patch
78

8-
from .benchmark import Benchmark
9+
from vectorq.vectorq_core.cache.embedding_store.embedding_metadata_storage.embedding_metadata_obj import (
10+
EmbeddingMetadataObj,
11+
)
12+
from vectorq.vectorq_core.vectorq_policy.strategies.bayesian import (
13+
VectorQBayesianPolicy,
14+
)
915

16+
# Use TYPE_CHECKING to avoid circular imports
17+
if TYPE_CHECKING:
18+
from benchmarks.benchmark import Benchmark
1019

11-
def plot_error_rate_relative(benchmark: Benchmark, FONT_SIZE=20):
20+
21+
def plot_error_rate_relative(benchmark: "Benchmark", FONT_SIZE=20):
1222
plt.rcParams.update({"font.size": FONT_SIZE})
1323
plt.figure(figsize=(16, 7))
1424
plt.plot(
@@ -42,7 +52,7 @@ def plot_error_rate_relative(benchmark: Benchmark, FONT_SIZE=20):
4252
plt.close()
4353

4454

45-
def plot_error_rate_absolute(benchmark: Benchmark, FONT_SIZE=20):
55+
def plot_error_rate_absolute(benchmark: "Benchmark", FONT_SIZE=20):
4656
plt.rcParams.update({"font.size": FONT_SIZE})
4757
plt.figure(figsize=(16, 7))
4858
plt.plot(benchmark.sample_sizes, benchmark.error_rates_absolute, color="blue")
@@ -69,7 +79,7 @@ def plot_error_rate_absolute(benchmark: Benchmark, FONT_SIZE=20):
6979
plt.close()
7080

7181

72-
def plot_relative_error_rate_step_size_(benchmark: Benchmark, FONT_SIZE=20):
82+
def plot_relative_error_rate_step_size_(benchmark: "Benchmark", FONT_SIZE=20):
7383
plt.rcParams.update({"font.size": FONT_SIZE})
7484
plt.figure(figsize=(16, 7))
7585
plt.plot(
@@ -101,7 +111,7 @@ def plot_relative_error_rate_step_size_(benchmark: Benchmark, FONT_SIZE=20):
101111
plt.close()
102112

103113

104-
def plot_reuse_rate(benchmark: Benchmark, FONT_SIZE=20):
114+
def plot_reuse_rate(benchmark: "Benchmark", FONT_SIZE=20):
105115
plt.rcParams.update({"font.size": FONT_SIZE})
106116
reuse_rates = [
107117
(reused / size * 100) if size > 0 else 0
@@ -133,7 +143,7 @@ def plot_reuse_rate(benchmark: Benchmark, FONT_SIZE=20):
133143
plt.close()
134144

135145

136-
def plot_relative_reuse_rate(benchmark: Benchmark, FONT_SIZE=20):
146+
def plot_relative_reuse_rate(benchmark: "Benchmark", FONT_SIZE=20):
137147
plt.rcParams.update({"font.size": FONT_SIZE})
138148
plt.figure(figsize=(16, 7))
139149
plt.plot(benchmark.sample_sizes, benchmark.relative_reuse_rates, color="blue")
@@ -161,7 +171,7 @@ def plot_relative_reuse_rate(benchmark: Benchmark, FONT_SIZE=20):
161171
plt.close()
162172

163173

164-
def plot_duration_step_size(benchmark: Benchmark, FONT_SIZE=20):
174+
def plot_duration_step_size(benchmark: "Benchmark", FONT_SIZE=20):
165175
plt.rcParams.update({"font.size": FONT_SIZE})
166176
plt.figure(figsize=(16, 7))
167177
plt.plot(
@@ -190,7 +200,7 @@ def plot_duration_step_size(benchmark: Benchmark, FONT_SIZE=20):
190200
plt.close()
191201

192202

193-
def plot_duration_trend(benchmark: Benchmark, FONT_SIZE=20):
203+
def plot_duration_trend(benchmark: "Benchmark", FONT_SIZE=20):
194204
plt.rcParams.update({"font.size": FONT_SIZE})
195205
plt.figure(figsize=(16, 7))
196206
# Convert seconds to minutes
@@ -221,7 +231,7 @@ def plot_duration_trend(benchmark: Benchmark, FONT_SIZE=20):
221231
plt.close()
222232

223233

224-
def plot_precision(benchmark: Benchmark, FONT_SIZE=20):
234+
def plot_precision(benchmark: "Benchmark", FONT_SIZE=20):
225235
plt.rcParams.update({"font.size": FONT_SIZE})
226236
plt.figure(figsize=(16, 7))
227237
plt.plot(benchmark.sample_sizes, benchmark.precision_list, color="blue")
@@ -237,7 +247,7 @@ def plot_precision(benchmark: Benchmark, FONT_SIZE=20):
237247
plt.close()
238248

239249

240-
def plot_recall(benchmark: Benchmark, FONT_SIZE=20):
250+
def plot_recall(benchmark: "Benchmark", FONT_SIZE=20):
241251
plt.rcParams.update({"font.size": FONT_SIZE})
242252
plt.figure(figsize=(16, 7))
243253
plt.plot(benchmark.sample_sizes, benchmark.recall_list, color="blue")
@@ -253,7 +263,7 @@ def plot_recall(benchmark: Benchmark, FONT_SIZE=20):
253263
plt.close()
254264

255265

256-
def plot_accuracy(benchmark: Benchmark, FONT_SIZE=20):
266+
def plot_accuracy(benchmark: "Benchmark", FONT_SIZE=20):
257267
plt.rcParams.update({"font.size": FONT_SIZE})
258268
plt.figure(figsize=(16, 7))
259269
plt.plot(benchmark.sample_sizes, benchmark.accuracy_list, color="blue")
@@ -269,7 +279,7 @@ def plot_accuracy(benchmark: Benchmark, FONT_SIZE=20):
269279
plt.close()
270280

271281

272-
def plot_cache_size(benchmark: Benchmark, FONT_SIZE=20):
282+
def plot_cache_size(benchmark: "Benchmark", FONT_SIZE=20):
273283
"""Plot cache size growth as samples are processed."""
274284
plt.rcParams.update({"font.size": FONT_SIZE})
275285
plt.figure(figsize=(16, 7))
@@ -283,7 +293,7 @@ def plot_cache_size(benchmark: Benchmark, FONT_SIZE=20):
283293
plt.close()
284294

285295

286-
def add_description(benchmark: Benchmark, plt):
296+
def add_description(benchmark: "Benchmark", plt):
287297
if benchmark.is_dynamic_threshold:
288298
description = (
289299
f"VectorQ, rnd_num_ub: {benchmark.rnd_num_ub}, Data Source: {os.path.basename(benchmark.filepath)}\n"
@@ -311,7 +321,7 @@ def add_description(benchmark: Benchmark, plt):
311321
)
312322

313323

314-
def plot_cache_hit_latency_vs_size(benchmark: Benchmark, FONT_SIZE=20):
324+
def plot_cache_hit_latency_vs_size(benchmark: "Benchmark", FONT_SIZE=20):
315325
cache_sizes = []
316326
hit_latencies = []
317327
cache_hits_count = [] # Track the cumulative number of cache hits
@@ -328,27 +338,18 @@ def plot_cache_hit_latency_vs_size(benchmark: Benchmark, FONT_SIZE=20):
328338
print("Not enough cache hits to plot cache hit latency vs cache size")
329339
return
330340

331-
# Remove outliers using IQR method
332341
latencies_array = np.array(hit_latencies)
333342
q1 = np.percentile(latencies_array, 25)
334343
q3 = np.percentile(latencies_array, 75)
335344
iqr = q3 - q1
336345
lower_bound = q1 - 1.5 * iqr
337346
upper_bound = q3 + 1.5 * iqr
338347

339-
# Filter out outliers
340348
outlier_mask = (latencies_array >= lower_bound) & (latencies_array <= upper_bound)
341349
filtered_cache_sizes = np.array(cache_sizes)[outlier_mask]
342350
filtered_hit_latencies = latencies_array[outlier_mask]
343351
filtered_cache_hits_count = np.array(cache_hits_count)[outlier_mask]
344352

345-
# Print how many outliers were removed
346-
num_outliers = len(hit_latencies) - len(filtered_hit_latencies)
347-
if num_outliers > 0:
348-
print(
349-
f"Removed {num_outliers} outliers from hit latencies (out of {len(hit_latencies)})"
350-
)
351-
352353
plt.rcParams.update({"font.size": FONT_SIZE})
353354
plt.figure(figsize=(16, 7))
354355

@@ -380,8 +381,7 @@ def plot_cache_hit_latency_vs_size(benchmark: Benchmark, FONT_SIZE=20):
380381
label=f"Trend: {z[0]:.6f}x + {z[1]:.4f}",
381382
linewidth=4,
382383
)
383-
except: # noqa: E722
384-
# If fitting fails, just continue without the trend line
384+
except Exception:
385385
pass
386386

387387
ax1.set_xlabel("Cache Size (MB)")
@@ -434,7 +434,7 @@ def plot_cache_hit_latency_vs_size(benchmark: Benchmark, FONT_SIZE=20):
434434

435435

436436
# TODO: LGS
437-
def plot_combined_thresholds_and_posteriors(benchmark: Benchmark):
437+
def plot_combined_thresholds_and_posteriors(benchmark: "Benchmark"):
438438
for idx, correct_similarities, incorrect_similarities, posteriors in zip(
439439
benchmark.correct_x.keys(),
440440
benchmark.correct_x.values(),
@@ -514,3 +514,108 @@ def plot_combined_thresholds_and_posteriors(benchmark: Benchmark):
514514
os.makedirs(output_folder_path)
515515
plt.savefig(filename, format="pdf", bbox_inches="tight")
516516
plt.close()
517+
518+
519+
def plot_bayesian_decision_boundary(benchmark: "Benchmark"):
    """Plot tau and the likelihood curve over similarity for each embedding.

    For every entry of ``benchmark.observations`` that has enough data, this
    fits a ``VectorQBayesianPolicy`` threshold estimate (``t_hat``), evaluates
    the policy's tau and likelihood on 100 evenly spaced similarity values in
    [0, 1], and writes one PDF per embedding to
    ``<benchmark.output_folder_path>/bayesian_decision_boundary/``.

    Nothing is plotted when ``benchmark.is_dynamic_threshold`` is falsy.

    Args:
        benchmark: Benchmark run providing ``is_dynamic_threshold``, ``delta``,
            ``observations``, ``gammas``, ``output_folder_path`` and
            ``timestamp``.
    """
    # Guard clause: the Bayesian policy is only built for dynamic thresholds.
    if not benchmark.is_dynamic_threshold:
        return

    policy = VectorQBayesianPolicy(delta=benchmark.delta)

    # Loop-invariant values, computed once instead of once per embedding.
    output_folder_path = benchmark.output_folder_path + "/bayesian_decision_boundary/"
    s_values = np.linspace(0.0, 1.0, 100)

    # NOTE(review): gammas is paired with observations purely by iteration
    # order - confirm both are kept in the same order by the benchmark.
    for (idx, observations), gamma in zip(
        benchmark.observations.items(), benchmark.gammas
    ):
        # Each observation is indexed as (similarity, label); similarities and
        # labels therefore always have the same length, so one check suffices
        # (it also covers the empty case).
        if len(observations) < 15:
            continue

        similarities = np.array([obs[0] for obs in observations])
        labels = np.array([obs[1] for obs in observations])
        # Label 1 is plotted as "correct", label 0 as "incorrect".
        correct_obs = similarities[labels == 1]
        incorrect_obs = similarities[labels == 0]

        # Require at least three examples of each class before fitting.
        if len(correct_obs) < 3 or len(incorrect_obs) < 3:
            continue

        # Placeholder metadata object that only carries this embedding's gamma.
        metadata = EmbeddingMetadataObj(embedding_id=-1, response="None")
        metadata.gamma = gamma

        t_hat = policy._estimate_parameters(similarities, labels, metadata)

        # Tau and likelihood evaluated at every similarity value.
        tau_values = [
            policy._get_tau(similarities, labels, s, t_hat, metadata)
            for s in s_values
        ]
        probs = [policy._likelihood(s, t_hat, gamma) for s in s_values]

        plt.figure(figsize=(12, 8))
        try:
            plt.plot(
                s_values,
                tau_values,
                "r-",
                linewidth=2,
                label="Tau (exploration probability)",
            )
            plt.plot(
                s_values,
                probs,
                "b--",
                linewidth=2,
                label=f"Probability curve (γ={gamma})",
            )
            plt.axvline(
                x=t_hat,
                color="g",
                linestyle="--",
                label=f"Decision boundary (t_hat={t_hat:.2f})",
            )

            # Observations are drawn as markers just above the x-axis.
            plt.scatter(
                correct_obs,
                [0.05] * len(correct_obs),
                color="green",
                label="Correct observations",
                s=80,
                alpha=0.7,
            )
            plt.scatter(
                incorrect_obs,
                [0.05] * len(incorrect_obs),
                color="red",
                label="Incorrect observations",
                s=80,
                alpha=0.7,
            )

            plt.xlim([0.0, 1.0])
            plt.ylim([0, 1.05])
            plt.xlabel("Similarity (s)", fontsize=18)
            plt.ylabel("Probability / Tau", fontsize=18)
            plt.title(
                f"Exploration Probability (Tau) vs. Similarity (δ={policy.delta})",
                fontsize=18,
            )
            plt.grid(True, alpha=0.3)
            plt.legend(fontsize=18)
            plt.tight_layout(rect=[0, 0.05, 1, 1])

            filename = (
                output_folder_path
                + f"decision_boundary_embedding_{idx}_{benchmark.timestamp}.pdf"
            )
            # exist_ok avoids the check-then-create race of exists() + makedirs().
            os.makedirs(output_folder_path, exist_ok=True)
            plt.savefig(filename, format="pdf", bbox_inches="tight")
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close()

0 commit comments

Comments
 (0)