diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py index 53e251e462..9370bcc0bb 100644 --- a/src/spikeinterface/benchmark/benchmark_plot_tools.py +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -432,6 +432,8 @@ def plot_performances_vs_snr( levels_to_keep=None, orientation="vertical", show_legend=True, + with_sigmoid_fit=True, + show_average_by_bin=False, axs=None, ): """ @@ -455,6 +457,10 @@ def plot_performances_vs_snr( The orientation of the plot. show_legend : bool, default True Show legend or not + show_sigmoid_fit : bool, default True + Show sigmoid that fit the performances. + show_average_by_bin : bool, default False + Instead of the sigmoid an average by bins can be plotted. axs : matplotlib.axes.Axes | None, default: None The axs to use for plotting. Should be the same size as len(performance_names). @@ -478,7 +484,12 @@ def plot_performances_vs_snr( raise ValueError("orientation must be 'vertical' or 'horizontal'") if axs is None: - fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=figsize, squeeze=True) + fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=figsize, squeeze=False) + if orientation == "vertical": + axs = axs[:, 0] + else: + axs = axs[0, :] + else: assert len(axs) == len(performance_names), "axs should have the same number of axes as performance_names" fig = axs[0].get_figure() @@ -512,8 +523,12 @@ def plot_performances_vs_snr( analyzer = study.get_sorting_analyzer(dataset_key=snr_dataset_reference) quality_metrics = analyzer.get_extension("quality_metrics").get_data() - x = quality_metrics["snr"].values - y = study.get_result(sub_key)["gt_comparison"].get_performance()[performance_name].values + x = quality_metrics["snr"].to_numpy(dtype="float64") + y = ( + study.get_result(sub_key)["gt_comparison"] + .get_performance()[performance_name] + .to_numpy(dtype="float64") + ) all_xs.append(x) all_ys.append(y) @@ -524,9 +539,17 @@ def plot_performances_vs_snr( ax.scatter(all_xs, all_ys, marker=".", label=label, color=color) ax.set_ylabel(performance_name) - popt = fit_sigmoid(all_xs, all_ys, p0=None) - xfit = np.linspace(0, max(x), 100) - ax.plot(xfit, sigmoid(xfit, *popt), color=color) + if with_sigmoid_fit: + popt = fit_sigmoid(all_xs, all_ys, p0=None) + xfit = np.linspace(0, max(x), 100) + ax.plot(xfit, sigmoid(xfit, *popt), color=color) + + if show_average_by_bin: + from scipy.stats import binned_statistic + + bins = np.linspace(np.min(all_xs), np.max(all_xs), 20) + average, bins, count = binned_statistic(all_xs, all_ys, statistic="mean", bins=bins) + ax.plot(bins[:-1] + (bins[1] - bins[0]) / 2.0, average, color=color) ax.set_ylim(-0.05, 1.05)