Skip to content

Commit ed47919

Browse files
authored
Merge branch 'main' into deps
2 parents 4663318 + fb72ed2 commit ed47919

15 files changed

Lines changed: 480 additions & 33 deletions
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""
22
Module to benchmark:
3-
* sorters
3+
* sorters with or without ground-truth
44
* some sorting components (clustering, motion, template matching)
55
"""
66

7+
from .residual_analysis import analyse_residual, make_residual_recording
78
from .benchmark_sorter import SorterStudy
9+
from .benchmark_sorter_without_gt import SorterStudyWithoutGroundTruth

src/spikeinterface/benchmark/benchmark_base.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import time
1010

1111

12-
from spikeinterface.core import SortingAnalyzer
12+
from spikeinterface.core import SortingAnalyzer, NumpySorting
1313
from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
1414
from spikeinterface import load, create_sorting_analyzer, load_sorting_analyzer
1515
from spikeinterface.widgets import get_some_colors
@@ -118,12 +118,23 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
118118
if isinstance(data, tuple):
119119
# old case : rec + sorting
120120
rec, gt_sorting = data
121-
analyzer = create_sorting_analyzer(
122-
gt_sorting, rec, sparse=True, format="binary_folder", folder=local_analyzer_folder
123-
)
124-
analyzer.compute("random_spikes")
125-
analyzer.compute("templates")
126-
analyzer.compute("noise_levels")
121+
122+
if gt_sorting is not None:
123+
analyzer = create_sorting_analyzer(
124+
gt_sorting, rec, sparse=True, format="binary_folder", folder=local_analyzer_folder
125+
)
126+
analyzer.compute("random_spikes")
127+
analyzer.compute("templates")
128+
analyzer.compute("noise_levels")
129+
else:
130+
# some study/benchmark has no GT sorting
131+
# in that case we still need an analyzer for internal API
132+
gt_sorting = NumpySorting.from_samples_and_labels(
133+
[np.array([])], [np.array([])], rec.sampling_frequency, unit_ids=None
134+
)
135+
analyzer = create_sorting_analyzer(
136+
gt_sorting, rec, sparse=False, format="binary_folder", folder=local_analyzer_folder
137+
)
127138
else:
128139
# new case : analzyer
129140
assert isinstance(data, SortingAnalyzer)
@@ -566,7 +577,10 @@ def _save_keys(self, saved_keys, folder):
566577
elif format == "zarr_templates":
567578
self.result[k].to_zarr(folder / k)
568579
elif format == "sorting_analyzer":
569-
pass
580+
analyzer_folder = folder / k
581+
if analyzer_folder.exists():
582+
shutil.rmtree(analyzer_folder)
583+
self.result[k].save_as(format="binary_folder", folder=analyzer_folder)
570584
else:
571585
raise ValueError(f"Save error {k} {format}")
572586

@@ -612,6 +626,10 @@ def load_folder(cls, folder):
612626
if zarr_folder.exists():
613627

614628
result[k] = Templates.from_zarr(zarr_folder)
629+
elif format == "sorting_analyzer":
630+
analyzer_folder = folder / k
631+
if analyzer_folder.exists():
632+
result[k] = load_sorting_analyzer(analyzer_folder)
615633

616634
return result
617635

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
"""
2+
This replaces the previous `GroundTruthStudy`.
3+
"""
4+
5+
import numpy as np
6+
from spikeinterface.core import NumpySorting, create_sorting_analyzer
7+
from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount
8+
from spikeinterface.sorters import run_sorter
9+
from spikeinterface.comparison import compare_multiple_sorters
10+
11+
from spikeinterface.benchmark import analyse_residual
12+
13+
14+
# TODO later integrate CollisionGTComparison optionally in this class.
15+
16+
17+
class SorterBenchmarkWithoutGroundTruth(Benchmark):
    """
    Benchmark of one sorter run on one recording, evaluated without ground truth.

    `run()` executes the sorter; `compute_result()` builds an in-memory
    analyzer on the sorter output and detects the peaks remaining in the
    residual recording.
    """

    # (result key, storage format) pairs produced by `run()`
    _run_key_saved = [
        ("sorting", "sorting"),
    ]
    # (result key, storage format) pairs produced after `run()`
    _result_key_saved = [
        # note that this multi_comp is the same across benchmarks (cases)
        ("multi_comp", "pickle"),
        ("sorter_analyzer", "sorting_analyzer"),
        ("peaks_from_residual", "npy"),
    ]

    def __init__(self, recording, gt_sorting, params, sorter_folder):
        """
        Parameters
        ----------
        recording : Recording
            The recording given to the sorter.
        gt_sorting : Sorting | None
            Stored as-is; not used by the methods of this class.
        params : dict
            Keyword arguments forwarded to `run_sorter()`; it must contain `sorter_name`.
        sorter_folder : Path | str
            Passed to `run_sorter()` as `folder`.
        """
        self.recording = recording
        self.gt_sorting = gt_sorting
        self.params = params
        self.sorter_folder = sorter_folder
        self.result = {}

    def run(self):
        """Run the sorter and keep its output as a `NumpySorting` in `self.result["sorting"]`."""
        # sorter_name must be one of the entries of self.params
        sorter_output = run_sorter(recording=self.recording, folder=self.sorter_folder, **self.params)
        self.result = {"sorting": NumpySorting.from_sorting(sorter_output)}

    def compute_result(self, residulal_peak_threshold=6, **job_kwargs):
        """
        Build the analyzer of the sorter output and detect peaks on the residual.

        Parameters
        ----------
        residulal_peak_threshold : float, default: 6
            Used as `detect_threshold` for the peak detection on the residual.
        **job_kwargs
            Forwarded to `create_sorting_analyzer()`, to some of the
            `analyzer.compute()` calls and to `analyse_residual()`.
        """
        analyzer = create_sorting_analyzer(
            self.result["sorting"], self.recording, sparse=True, format="memory", **job_kwargs
        )
        for extension_name in ("random_spikes", "templates", "noise_levels"):
            analyzer.compute(extension_name)

        amplitude_extensions = {
            "spike_amplitudes": {},
            "amplitude_scalings": {"handle_collisions": False},
        }
        analyzer.compute(amplitude_extensions, **job_kwargs)
        analyzer.compute("quality_metrics", **job_kwargs)

        residual_detection = dict(
            method="locally_exclusive",
            peak_sign="neg",
            detect_threshold=residulal_peak_threshold,
        )
        # only the detected peaks are kept; the residual recording itself is discarded
        _, residual_peaks = analyse_residual(analyzer, detect_peaks_kwargs=residual_detection, **job_kwargs)

        self.result["sorter_analyzer"] = analyzer
        self.result["peaks_from_residual"] = residual_peaks
64+
65+
66+
class SorterStudyWithoutGroundTruth(BenchmarkStudy):
    """
    Alternative to SorterStudy for datasets that do not have a ground truth.

    This is mainly based on the residual analysis. In addition, the sortings
    of all cases sharing the same dataset are compared with each other.
    """

    benchmark_class = SorterBenchmarkWithoutGroundTruth

    def create_benchmark(self, key):
        """
        Create the benchmark object of one case.

        Parameters
        ----------
        key : hashable
            The case key, as found in `self.cases`.

        Returns
        -------
        benchmark : SorterBenchmarkWithoutGroundTruth
        """
        dataset_key = self.cases[key]["dataset"]
        recording, gt_sorting = self.datasets[dataset_key]
        params = self.cases[key]["params"]
        sorter_folder = self.folder / "sorters" / self.key_to_str(key)
        benchmark = SorterBenchmarkWithoutGroundTruth(recording, gt_sorting, params, sorter_folder)
        return benchmark

    def _get_comparison_groups(self):
        """
        Group the case keys by dataset key.

        Returns
        -------
        groups : dict
            Maps each dataset key to the list of case keys using that dataset.
            The multi-comparison is done within each of these groups.
        """
        groups = {}
        for case_key in self.cases.keys():
            data_key = self.cases[case_key]["dataset"]
            groups.setdefault(data_key, []).append(case_key)
        return groups

    def compute_results(
        self, case_keys=None, verbose=False, delta_time=0.4, match_score=0.5, chance_score=0.1, **result_params
    ):
        """
        Compute the per-case results and then the multi-comparison per dataset.

        Parameters
        ----------
        case_keys : None
            Must be None: results are always computed for all cases because
            the multi-comparison is not computed case by case.
        verbose : bool, default: False
            Forwarded to `BenchmarkStudy.compute_results()` and to
            `compare_multiple_sorters()`.
        delta_time : float, default: 0.4
            Forwarded to `compare_multiple_sorters()`.
        match_score : float, default: 0.5
            Forwarded to `compare_multiple_sorters()`.
        chance_score : float, default: 0.1
            Forwarded to `compare_multiple_sorters()`.
        **result_params
            Forwarded to `BenchmarkStudy.compute_results()`.
        """
        # Here we need a hack because the results are not computed case by case but all at once
        assert case_keys is None, "SorterStudyWithoutGroundTruth do not permit compute_results for sub cases"

        # always the full list
        case_keys = list(self.cases.keys())

        # First: case by case, this internally calls SorterBenchmarkWithoutGroundTruth.compute_result()
        BenchmarkStudy.compute_results(self, case_keys=case_keys, verbose=verbose, **result_params)

        # Then compute the multi-comparison for the cases that share the same dataset key.
        groups = self._get_comparison_groups()

        for data_key, group in groups.items():
            sorting_list = [self.get_result(key)["sorting"] for key in group]
            multi_comp = compare_multiple_sorters(
                sorting_list,
                name_list=list(group),
                delta_time=delta_time,
                # use the caller's values, not hard-coded ones
                match_score=match_score,
                chance_score=chance_score,
                agreement_method="count",
                n_jobs=-1,
                spiketrain_mode="union",
                verbose=verbose,
                do_matching=True,
            )
            # The same multi_comp is stored for each case of this group only:
            # cases of another dataset must keep their own comparison.
            for key in group:
                benchmark = self.benchmarks[key]
                benchmark.result["multi_comp"] = multi_comp
                benchmark.save_result(self.folder / "results" / self.key_to_str(key))

    def plot_residual_peak_amplitudes(self, figsize=None):
        """
        Plot, for each dataset, the amplitude histogram of the residual peaks of every case.

        One figure is created per dataset key, with one curve per case that
        uses this dataset. All curves of a figure share the same 200 bin edges.

        Parameters
        ----------
        figsize : tuple | None, default: None
            Forwarded to `matplotlib.pyplot.subplots()`.
        """
        import matplotlib.pyplot as plt

        groups = self._get_comparison_groups()
        colors = self.get_colors()

        for data_key, group in groups.items():
            fig, ax = plt.subplots(figsize=figsize)

            amplitudes_by_case = {key: self.get_result(key)["peaks_from_residual"]["amplitude"] for key in group}

            # Starting both limits at 0 makes the shared range always contain zero.
            # This has to be settled before the bin edges are built to have any effect.
            lim0, lim1 = 0, 0
            for amplitudes in amplitudes_by_case.values():
                if amplitudes.size == 0:
                    # a case with no residual peak has no min/max to contribute
                    continue
                lim0 = min(lim0, np.min(amplitudes))
                lim1 = max(lim1, np.max(amplitudes))

            bins = np.linspace(lim0, lim1, 200)

            for key in group:
                count, _ = np.histogram(amplitudes_by_case[key], bins=bins)
                ax.plot(bins[:-1], count, color=colors[key], label=self.cases[key]["label"])

            ax.legend()
164+
# def plot_quality_metrics_comparison_on_agreement(self, qm_name='rp_contamination', figsize=None):
165+
# import matplotlib.pyplot as plt
166+
167+
# groups = self._get_comparison_groups()
168+
169+
# for data_key, group in groups.items():
170+
# n = len(group)
171+
# fig, axs = plt.subplots(ncols=n - 1, nrows=n - 1, figsize=figsize, squeeze=False)
172+
# for i, key1 in enumerate(group):
173+
# for j, key2 in enumerate(group):
174+
# if i < j:
175+
# ax = axs[i, j - 1]
176+
# label1 = self.cases[key1]['label']
177+
# label2 = self.cases[key2]['label']
178+
179+
# if i == j - 1:
180+
# ax.set_xlabel(label2)
181+
# ax.set_ylabel(label1)
182+
183+
# multi_comp = self.get_result(key1)['multi_comp']
184+
# comp = multi_comp.comparisons[key1, key2]
185+
186+
# match_12 = comp.hungarian_match_12
187+
# if match_12.dtype.kind =='i':
188+
# mask = match_12.values != -1
189+
# if match_12.dtype.kind =='U':
190+
# mask = match_12.values != ''
191+
192+
# common_unit1_ids = match_12[mask].index
193+
# common_unit2_ids = match_12[mask].values
194+
# metrics1 = self.get_result(key1)["sorter_analyzer"].get_extension("quality_metrics").get_data()
195+
# metrics2 = self.get_result(key2)["sorter_analyzer"].get_extension("quality_metrics").get_data()
196+
197+
# values1 = metrics1.loc[common_unit1_ids, qm_name].values
198+
# values2 = metrics2.loc[common_unit2_ids, qm_name].values
199+
200+
# print(common_unit1_ids, metrics1.columns, values1)
201+
# print(common_unit2_ids, metrics2.columns, values2)
202+
203+
# ax.scatter(values1, values2)
204+
# if i != j - 1:
205+
# ax.set_xlabel("")
206+
# ax.set_ylabel("")
207+
# ax.set_xticks([])
208+
# ax.set_yticks([])
209+
# ax.set_xticklabels([])
210+
# ax.set_yticklabels([])
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from spikeinterface.core.generate import InjectTemplatesRecording
2+
3+
4+
def analyse_residual(
    analyzer,
    detect_peaks_kwargs=None,
    **job_kwargs,
):
    """
    Create the residual by removing each spike from the recording and detect peaks on it.

    This takes into account the spike amplitude scaling, so the analyzer needs the
    "amplitude_scalings" extension.
    A peak detector is then run on the residual traces and the number of peaks can be
    analyzed (the fewer the better).

    This residual is not perfect at the moment because it does not take into account
    the jitter per spike, so the residual can be high for high amplitudes when there is
    an inherent jitter per spike.

    Parameters
    ----------
    analyzer : SortingAnalyzer
        The analyzer the residual is built from.
    detect_peaks_kwargs : dict | None, default: None
        Keyword arguments forwarded to `detect_peaks()`.
        If None, `dict(method="locally_exclusive", peak_sign="both", detect_threshold=6.0)` is used.
    **job_kwargs
        Forwarded to `detect_peaks()`.

    Returns
    -------
    residual : Recording
        The residual recording.
    peaks : np.array
        The peaks vector detected on the residual.
    """
    from spikeinterface.sortingcomponents.peak_detection import detect_peaks

    if detect_peaks_kwargs is None:
        # built per call so that no mutable object is shared through the signature
        detect_peaks_kwargs = dict(
            method="locally_exclusive",
            peak_sign="both",
            detect_threshold=6.0,
        )

    residual = make_residual_recording(analyzer)

    peaks = detect_peaks(residual, **detect_peaks_kwargs, **job_kwargs)

    return residual, peaks
40+
41+
42+
def make_residual_recording(analyzer):
    """
    Make a lazy residual recording from an analyzer.

    The sign-inverted templates of the analyzer are injected at the spikes of
    `analyzer.sorting` on top of `analyzer.recording`, scaled by the data of the
    "amplitude_scalings" extension.

    Parameters
    ----------
    analyzer : SortingAnalyzer
        Must have the "templates" and "amplitude_scalings" extensions.

    Returns
    -------
    residual : Recording
        The residual recording.
    """
    templates_ext = analyzer.get_extension("templates")
    scalings_ext = analyzer.get_extension("amplitude_scalings")

    templates = templates_ext.get_templates(outputs="Templates")

    # work on a copy so the templates held by the analyzer are not modified
    inverted_templates = templates.templates_array.copy()
    inverted_templates *= -1

    per_spike_scaling = scalings_ext.get_data()

    residual = InjectTemplatesRecording(
        analyzer.sorting,
        inverted_templates,
        nbefore=templates.nbefore,
        parent_recording=analyzer.recording,
        amplitude_factor=per_spike_scaling,
    )
    residual.name = "ResidualRecording"
    return residual

0 commit comments

Comments
 (0)