@@ -432,6 +432,8 @@ def plot_performances_vs_snr(
432432 levels_to_keep = None ,
433433 orientation = "vertical" ,
434434 show_legend = True ,
435+ with_sigmoid_fit = True ,
436+ show_average_by_bin = False ,
435437 axs = None ,
436438):
437439 """
@@ -455,6 +457,10 @@ def plot_performances_vs_snr(
455457 The orientation of the plot.
456458 show_legend : bool, default True
457459 Show legend or not
460+ show_sigmoid_fit : bool, default True
461+ Show sigmoid that fit the performances.
462+ show_average_by_bin : bool, default False
463+ Instead of the sigmoid an average by bins can be plotted.
458464 axs : matplotlib.axes.Axes | None, default: None
459465 The axs to use for plotting. Should be the same size as len(performance_names).
460466
@@ -478,7 +484,12 @@ def plot_performances_vs_snr(
478484 raise ValueError ("orientation must be 'vertical' or 'horizontal'" )
479485
480486 if axs is None :
481- fig , axs = plt .subplots (ncols = ncols , nrows = nrows , figsize = figsize , squeeze = True )
487+ fig , axs = plt .subplots (ncols = ncols , nrows = nrows , figsize = figsize , squeeze = False )
488+ if orientation == "vertical" :
489+ axs = axs [:, 0 ]
490+ else :
491+ axs = axs [0 , :]
492+
482493 else :
483494 assert len (axs ) == len (performance_names ), "axs should have the same number of axes as performance_names"
484495 fig = axs [0 ].get_figure ()
@@ -512,8 +523,12 @@ def plot_performances_vs_snr(
512523 analyzer = study .get_sorting_analyzer (dataset_key = snr_dataset_reference )
513524
514525 quality_metrics = analyzer .get_extension ("quality_metrics" ).get_data ()
515- x = quality_metrics ["snr" ].values
516- y = study .get_result (sub_key )["gt_comparison" ].get_performance ()[performance_name ].values
526+ x = quality_metrics ["snr" ].to_numpy (dtype = "float64" )
527+ y = (
528+ study .get_result (sub_key )["gt_comparison" ]
529+ .get_performance ()[performance_name ]
530+ .to_numpy (dtype = "float64" )
531+ )
517532 all_xs .append (x )
518533 all_ys .append (y )
519534
@@ -524,9 +539,17 @@ def plot_performances_vs_snr(
524539 ax .scatter (all_xs , all_ys , marker = "." , label = label , color = color )
525540 ax .set_ylabel (performance_name )
526541
527- popt = fit_sigmoid (all_xs , all_ys , p0 = None )
528- xfit = np .linspace (0 , max (x ), 100 )
529- ax .plot (xfit , sigmoid (xfit , * popt ), color = color )
542+ if with_sigmoid_fit :
543+ popt = fit_sigmoid (all_xs , all_ys , p0 = None )
544+ xfit = np .linspace (0 , max (x ), 100 )
545+ ax .plot (xfit , sigmoid (xfit , * popt ), color = color )
546+
547+ if show_average_by_bin :
548+ from scipy .stats import binned_statistic
549+
550+ bins = np .linspace (np .min (all_xs ), np .max (all_xs ), 20 )
551+ average , bins , count = binned_statistic (all_xs , all_ys , statistic = "mean" , bins = bins )
552+ ax .plot (bins [:- 1 ] + (bins [1 ] - bins [0 ]) / 2.0 , average , color = color )
530553
531554 ax .set_ylim (- 0.05 , 1.05 )
532555
0 commit comments