Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions src/spikeinterface/benchmark/benchmark_plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,8 @@ def plot_performances_vs_snr(
levels_to_keep=None,
orientation="vertical",
show_legend=True,
with_sigmoid_fit=True,
show_average_by_bin=False,
axs=None,
):
"""
Expand All @@ -455,6 +457,10 @@ def plot_performances_vs_snr(
The orientation of the plot.
show_legend : bool, default True
Show legend or not
show_sigmoid_fit : bool, default True
Show sigmoid that fit the performances.
show_average_by_bin : bool, default False
Instead of the sigmoid an average by bins can be plotted.
axs : matplotlib.axes.Axes | None, default: None
The axs to use for plotting. Should be the same size as len(performance_names).

Expand All @@ -478,7 +484,12 @@ def plot_performances_vs_snr(
raise ValueError("orientation must be 'vertical' or 'horizontal'")

if axs is None:
fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=figsize, squeeze=True)
fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=figsize, squeeze=False)
if orientation == "vertical":
axs = axs[:, 0]
else:
axs = axs[0, :]

else:
assert len(axs) == len(performance_names), "axs should have the same number of axes as performance_names"
fig = axs[0].get_figure()
Expand Down Expand Up @@ -512,8 +523,12 @@ def plot_performances_vs_snr(
analyzer = study.get_sorting_analyzer(dataset_key=snr_dataset_reference)

quality_metrics = analyzer.get_extension("quality_metrics").get_data()
x = quality_metrics["snr"].values
y = study.get_result(sub_key)["gt_comparison"].get_performance()[performance_name].values
x = quality_metrics["snr"].to_numpy(dtype="float64")
y = (
study.get_result(sub_key)["gt_comparison"]
.get_performance()[performance_name]
.to_numpy(dtype="float64")
)
all_xs.append(x)
all_ys.append(y)

Expand All @@ -524,9 +539,17 @@ def plot_performances_vs_snr(
ax.scatter(all_xs, all_ys, marker=".", label=label, color=color)
ax.set_ylabel(performance_name)

popt = fit_sigmoid(all_xs, all_ys, p0=None)
xfit = np.linspace(0, max(x), 100)
ax.plot(xfit, sigmoid(xfit, *popt), color=color)
if with_sigmoid_fit:
popt = fit_sigmoid(all_xs, all_ys, p0=None)
xfit = np.linspace(0, max(x), 100)
ax.plot(xfit, sigmoid(xfit, *popt), color=color)

if show_average_by_bin:
from scipy.stats import binned_statistic

bins = np.linspace(np.min(all_xs), np.max(all_xs), 20)
average, bins, count = binned_statistic(all_xs, all_ys, statistic="mean", bins=bins)
ax.plot(bins[:-1] + (bins[1] - bins[0]) / 2.0, average, color=color)

ax.set_ylim(-0.05, 1.05)

Expand Down
Loading