11import os
2+ from typing import TYPE_CHECKING
23
34import matplotlib .pyplot as plt
45import numpy as np
56import seaborn as sns
67from matplotlib .patches import Patch
78
8- from .benchmark import Benchmark
9+ from vectorq .vectorq_core .cache .embedding_store .embedding_metadata_storage .embedding_metadata_obj import (
10+ EmbeddingMetadataObj ,
11+ )
12+ from vectorq .vectorq_core .vectorq_policy .strategies .bayesian import (
13+ VectorQBayesianPolicy ,
14+ )
915
16+ # Use TYPE_CHECKING to avoid circular imports
17+ if TYPE_CHECKING :
18+ from benchmarks .benchmark import Benchmark
1019
11- def plot_error_rate_relative (benchmark : Benchmark , FONT_SIZE = 20 ):
20+
21+ def plot_error_rate_relative (benchmark : "Benchmark" , FONT_SIZE = 20 ):
1222 plt .rcParams .update ({"font.size" : FONT_SIZE })
1323 plt .figure (figsize = (16 , 7 ))
1424 plt .plot (
@@ -42,7 +52,7 @@ def plot_error_rate_relative(benchmark: Benchmark, FONT_SIZE=20):
4252 plt .close ()
4353
4454
45- def plot_error_rate_absolute (benchmark : Benchmark , FONT_SIZE = 20 ):
55+ def plot_error_rate_absolute (benchmark : " Benchmark" , FONT_SIZE = 20 ):
4656 plt .rcParams .update ({"font.size" : FONT_SIZE })
4757 plt .figure (figsize = (16 , 7 ))
4858 plt .plot (benchmark .sample_sizes , benchmark .error_rates_absolute , color = "blue" )
@@ -69,7 +79,7 @@ def plot_error_rate_absolute(benchmark: Benchmark, FONT_SIZE=20):
6979 plt .close ()
7080
7181
72- def plot_relative_error_rate_step_size_ (benchmark : Benchmark , FONT_SIZE = 20 ):
82+ def plot_relative_error_rate_step_size_ (benchmark : " Benchmark" , FONT_SIZE = 20 ):
7383 plt .rcParams .update ({"font.size" : FONT_SIZE })
7484 plt .figure (figsize = (16 , 7 ))
7585 plt .plot (
@@ -101,7 +111,7 @@ def plot_relative_error_rate_step_size_(benchmark: Benchmark, FONT_SIZE=20):
101111 plt .close ()
102112
103113
104- def plot_reuse_rate (benchmark : Benchmark , FONT_SIZE = 20 ):
114+ def plot_reuse_rate (benchmark : " Benchmark" , FONT_SIZE = 20 ):
105115 plt .rcParams .update ({"font.size" : FONT_SIZE })
106116 reuse_rates = [
107117 (reused / size * 100 ) if size > 0 else 0
@@ -133,7 +143,7 @@ def plot_reuse_rate(benchmark: Benchmark, FONT_SIZE=20):
133143 plt .close ()
134144
135145
136- def plot_relative_reuse_rate (benchmark : Benchmark , FONT_SIZE = 20 ):
146+ def plot_relative_reuse_rate (benchmark : " Benchmark" , FONT_SIZE = 20 ):
137147 plt .rcParams .update ({"font.size" : FONT_SIZE })
138148 plt .figure (figsize = (16 , 7 ))
139149 plt .plot (benchmark .sample_sizes , benchmark .relative_reuse_rates , color = "blue" )
@@ -161,7 +171,7 @@ def plot_relative_reuse_rate(benchmark: Benchmark, FONT_SIZE=20):
161171 plt .close ()
162172
163173
164- def plot_duration_step_size (benchmark : Benchmark , FONT_SIZE = 20 ):
174+ def plot_duration_step_size (benchmark : " Benchmark" , FONT_SIZE = 20 ):
165175 plt .rcParams .update ({"font.size" : FONT_SIZE })
166176 plt .figure (figsize = (16 , 7 ))
167177 plt .plot (
@@ -190,7 +200,7 @@ def plot_duration_step_size(benchmark: Benchmark, FONT_SIZE=20):
190200 plt .close ()
191201
192202
193- def plot_duration_trend (benchmark : Benchmark , FONT_SIZE = 20 ):
203+ def plot_duration_trend (benchmark : " Benchmark" , FONT_SIZE = 20 ):
194204 plt .rcParams .update ({"font.size" : FONT_SIZE })
195205 plt .figure (figsize = (16 , 7 ))
196206 # Convert seconds to minutes
@@ -221,7 +231,7 @@ def plot_duration_trend(benchmark: Benchmark, FONT_SIZE=20):
221231 plt .close ()
222232
223233
224- def plot_precision (benchmark : Benchmark , FONT_SIZE = 20 ):
234+ def plot_precision (benchmark : " Benchmark" , FONT_SIZE = 20 ):
225235 plt .rcParams .update ({"font.size" : FONT_SIZE })
226236 plt .figure (figsize = (16 , 7 ))
227237 plt .plot (benchmark .sample_sizes , benchmark .precision_list , color = "blue" )
@@ -237,7 +247,7 @@ def plot_precision(benchmark: Benchmark, FONT_SIZE=20):
237247 plt .close ()
238248
239249
240- def plot_recall (benchmark : Benchmark , FONT_SIZE = 20 ):
250+ def plot_recall (benchmark : " Benchmark" , FONT_SIZE = 20 ):
241251 plt .rcParams .update ({"font.size" : FONT_SIZE })
242252 plt .figure (figsize = (16 , 7 ))
243253 plt .plot (benchmark .sample_sizes , benchmark .recall_list , color = "blue" )
@@ -253,7 +263,7 @@ def plot_recall(benchmark: Benchmark, FONT_SIZE=20):
253263 plt .close ()
254264
255265
256- def plot_accuracy (benchmark : Benchmark , FONT_SIZE = 20 ):
266+ def plot_accuracy (benchmark : " Benchmark" , FONT_SIZE = 20 ):
257267 plt .rcParams .update ({"font.size" : FONT_SIZE })
258268 plt .figure (figsize = (16 , 7 ))
259269 plt .plot (benchmark .sample_sizes , benchmark .accuracy_list , color = "blue" )
@@ -269,7 +279,7 @@ def plot_accuracy(benchmark: Benchmark, FONT_SIZE=20):
269279 plt .close ()
270280
271281
272- def plot_cache_size (benchmark : Benchmark , FONT_SIZE = 20 ):
282+ def plot_cache_size (benchmark : " Benchmark" , FONT_SIZE = 20 ):
273283 """Plot cache size growth as samples are processed."""
274284 plt .rcParams .update ({"font.size" : FONT_SIZE })
275285 plt .figure (figsize = (16 , 7 ))
@@ -283,7 +293,7 @@ def plot_cache_size(benchmark: Benchmark, FONT_SIZE=20):
283293 plt .close ()
284294
285295
286- def add_description (benchmark : Benchmark , plt ):
296+ def add_description (benchmark : " Benchmark" , plt ):
287297 if benchmark .is_dynamic_threshold :
288298 description = (
289299 f"VectorQ, rnd_num_ub: { benchmark .rnd_num_ub } , Data Source: { os .path .basename (benchmark .filepath )} \n "
@@ -311,7 +321,7 @@ def add_description(benchmark: Benchmark, plt):
311321 )
312322
313323
314- def plot_cache_hit_latency_vs_size (benchmark : Benchmark , FONT_SIZE = 20 ):
324+ def plot_cache_hit_latency_vs_size (benchmark : " Benchmark" , FONT_SIZE = 20 ):
315325 cache_sizes = []
316326 hit_latencies = []
317327 cache_hits_count = [] # Track the cumulative number of cache hits
@@ -328,27 +338,18 @@ def plot_cache_hit_latency_vs_size(benchmark: Benchmark, FONT_SIZE=20):
328338 print ("Not enough cache hits to plot cache hit latency vs cache size" )
329339 return
330340
331- # Remove outliers using IQR method
332341 latencies_array = np .array (hit_latencies )
333342 q1 = np .percentile (latencies_array , 25 )
334343 q3 = np .percentile (latencies_array , 75 )
335344 iqr = q3 - q1
336345 lower_bound = q1 - 1.5 * iqr
337346 upper_bound = q3 + 1.5 * iqr
338347
339- # Filter out outliers
340348 outlier_mask = (latencies_array >= lower_bound ) & (latencies_array <= upper_bound )
341349 filtered_cache_sizes = np .array (cache_sizes )[outlier_mask ]
342350 filtered_hit_latencies = latencies_array [outlier_mask ]
343351 filtered_cache_hits_count = np .array (cache_hits_count )[outlier_mask ]
344352
345- # Print how many outliers were removed
346- num_outliers = len (hit_latencies ) - len (filtered_hit_latencies )
347- if num_outliers > 0 :
348- print (
349- f"Removed { num_outliers } outliers from hit latencies (out of { len (hit_latencies )} )"
350- )
351-
352353 plt .rcParams .update ({"font.size" : FONT_SIZE })
353354 plt .figure (figsize = (16 , 7 ))
354355
@@ -380,8 +381,7 @@ def plot_cache_hit_latency_vs_size(benchmark: Benchmark, FONT_SIZE=20):
380381 label = f"Trend: { z [0 ]:.6f} x + { z [1 ]:.4f} " ,
381382 linewidth = 4 ,
382383 )
383- except : # noqa: E722
384- # If fitting fails, just continue without the trend line
384+ except Exception :
385385 pass
386386
387387 ax1 .set_xlabel ("Cache Size (MB)" )
@@ -434,7 +434,7 @@ def plot_cache_hit_latency_vs_size(benchmark: Benchmark, FONT_SIZE=20):
434434
435435
436436# TODO: LGS
437- def plot_combined_thresholds_and_posteriors (benchmark : Benchmark ):
437+ def plot_combined_thresholds_and_posteriors (benchmark : " Benchmark" ):
438438 for idx , correct_similarities , incorrect_similarities , posteriors in zip (
439439 benchmark .correct_x .keys (),
440440 benchmark .correct_x .values (),
@@ -514,3 +514,108 @@ def plot_combined_thresholds_and_posteriors(benchmark: Benchmark):
514514 os .makedirs (output_folder_path )
515515 plt .savefig (filename , format = "pdf" , bbox_inches = "tight" )
516516 plt .close ()
517+
518+
519+ def plot_bayesian_decision_boundary (benchmark : "Benchmark" ):
520+ if benchmark .is_dynamic_threshold :
521+ vectorQ = VectorQBayesianPolicy (delta = benchmark .delta )
522+
523+ for idx , observations , gamma in zip (
524+ benchmark .observations .keys (),
525+ benchmark .observations .values (),
526+ benchmark .gammas ,
527+ ):
528+ if len (observations ) == 0 :
529+ continue
530+
531+ metadata = EmbeddingMetadataObj (embedding_id = - 1 , response = "None" )
532+ metadata .gamma = gamma
533+
534+ similarities = np .array ([obs [0 ] for obs in observations ])
535+ labels = np .array ([obs [1 ] for obs in observations ])
536+ correct_obs = np .array ([obs [0 ] for obs in observations if obs [1 ] == 1 ])
537+ incorrect_obs = np .array ([obs [0 ] for obs in observations if obs [1 ] == 0 ])
538+
539+ if (
540+ len (similarities ) < 15
541+ or len (labels ) < 15
542+ or len (correct_obs ) < 3
543+ or len (incorrect_obs ) < 3
544+ ):
545+ continue
546+
547+ t_hat = vectorQ ._estimate_parameters (similarities , labels , metadata )
548+
549+ s_values = np .linspace (0.0 , 1.0 , 100 )
550+
551+ # Calculate tau for each similarity value
552+ tau_values = []
553+ for s in s_values :
554+ tau = vectorQ ._get_tau (similarities , labels , s , t_hat , metadata )
555+ tau_values .append (tau )
556+
557+ # Calculate probability for each similarity value
558+ probs = [vectorQ ._likelihood (s , t_hat , gamma ) for s in s_values ]
559+
560+ plt .figure (figsize = (12 , 8 ))
561+ plt .plot (
562+ s_values ,
563+ tau_values ,
564+ "r-" ,
565+ linewidth = 2 ,
566+ label = "Tau (exploration probability)" ,
567+ )
568+ plt .plot (
569+ s_values ,
570+ probs ,
571+ "b--" ,
572+ linewidth = 2 ,
573+ label = f"Probability curve (γ={ gamma } )" ,
574+ )
575+ plt .axvline (
576+ x = t_hat ,
577+ color = "g" ,
578+ linestyle = "--" ,
579+ label = f"Decision boundary (t_hat={ t_hat :.2f} )" ,
580+ )
581+
582+ plt .scatter (
583+ correct_obs ,
584+ [0.05 ] * len (correct_obs ),
585+ color = "green" ,
586+ label = "Correct observations" ,
587+ s = 80 ,
588+ alpha = 0.7 ,
589+ )
590+ plt .scatter (
591+ incorrect_obs ,
592+ [0.05 ] * len (incorrect_obs ),
593+ color = "red" ,
594+ label = "Incorrect observations" ,
595+ s = 80 ,
596+ alpha = 0.7 ,
597+ )
598+
599+ plt .xlim ([0.0 , 1.0 ])
600+ plt .ylim ([0 , 1.05 ])
601+ plt .xlabel ("Similarity (s)" , fontsize = 18 )
602+ plt .ylabel ("Probability / Tau" , fontsize = 18 )
603+ plt .title (
604+ f"Exploration Probability (Tau) vs. Similarity (δ={ vectorQ .delta } )" ,
605+ fontsize = 18 ,
606+ )
607+ plt .grid (True , alpha = 0.3 )
608+ plt .legend (fontsize = 18 )
609+ plt .tight_layout (rect = [0 , 0.05 , 1 , 1 ])
610+
611+ output_folder_path = (
612+ benchmark .output_folder_path + "/bayesian_decision_boundary/"
613+ )
614+ filename = (
615+ benchmark .output_folder_path
616+ + f"/bayesian_decision_boundary/decision_boundary_embedding_{ idx } _{ benchmark .timestamp } .pdf"
617+ )
618+ if output_folder_path and not os .path .exists (output_folder_path ):
619+ os .makedirs (output_folder_path )
620+ plt .savefig (filename , format = "pdf" , bbox_inches = "tight" )
621+ plt .close ()
0 commit comments