Skip to content

Commit aa8384c

Browse files
committed
oups
2 parents bd9f85b + 53cfbcc commit aa8384c

8 files changed

Lines changed: 49 additions & 59 deletions

File tree

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""
22
Module to benchmark:
3-
* sorters
3+
* sorters with or without ground-truth
44
* some sorting components (clustering, motion, template matching)
55
"""
66

7+
from .residual_analysis import analyse_residual, make_residual_recording
78
from .benchmark_sorter import SorterStudy
9+
from .benchmark_sorter_without_gt import SorterStudyWithoutGroundTruth

src/spikeinterface/benchmark/benchmark_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import time
1010

1111

12-
1312
from spikeinterface.core import SortingAnalyzer, NumpySorting
1413
from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
1514
from spikeinterface import load, create_sorting_analyzer, load_sorting_analyzer
@@ -130,7 +129,9 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
130129
else:
131130
# some study/benchmark has no GT sorting
132131
# in that case we still need an analyzer for internal API
133-
gt_sorting = NumpySorting.from_samples_and_labels([np.array([])], [np.array([])], rec.sampling_frequency, unit_ids=None)
132+
gt_sorting = NumpySorting.from_samples_and_labels(
133+
[np.array([])], [np.array([])], rec.sampling_frequency, unit_ids=None
134+
)
134135
analyzer = create_sorting_analyzer(
135136
gt_sorting, rec, sparse=False, format="binary_folder", folder=local_analyzer_folder
136137
)

src/spikeinterface/benchmark/benchmark_sorter_without_gt.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -29,30 +29,24 @@ def run(self):
2929
self.result = {"sorting": sorting}
3030

3131
def compute_result(self, residulal_peak_threshold=6, **job_kwargs):
32-
33-
sorting = self.result['sorting']
34-
analyzer = create_sorting_analyzer(
35-
sorting, self.recording, sparse=True, format="memory", **job_kwargs
36-
)
32+
33+
sorting = self.result["sorting"]
34+
analyzer = create_sorting_analyzer(sorting, self.recording, sparse=True, format="memory", **job_kwargs)
3735
analyzer.compute("random_spikes")
3836
analyzer.compute("templates")
3937
analyzer.compute("noise_levels")
40-
analyzer.compute(
41-
{
42-
"spike_amplitudes" : {},
43-
"amplitude_scalings" : {"handle_collisions": False}
44-
},
45-
**job_kwargs)
38+
analyzer.compute({"spike_amplitudes": {}, "amplitude_scalings": {"handle_collisions": False}}, **job_kwargs)
4639

4740
analyzer.compute("quality_metrics", **job_kwargs)
4841

4942
residual, peaks = analyse_residual(
50-
analyzer, detect_peaks_kwargs=dict(
43+
analyzer,
44+
detect_peaks_kwargs=dict(
5145
method="locally_exclusive",
5246
peak_sign="neg",
5347
detect_threshold=residulal_peak_threshold,
5448
),
55-
**job_kwargs
49+
**job_kwargs,
5650
)
5751

5852
self.result["sorter_analyzer"] = analyzer
@@ -66,11 +60,9 @@ def compute_result(self, residulal_peak_threshold=6, **job_kwargs):
6660
("multi_comp", "pickle"),
6761
("sorter_analyzer", "sorting_analyzer"),
6862
("peaks_from_residual", "npy"),
69-
7063
]
7164

7265

73-
7466
class SorterStudyWithoutGroundTruth(BenchmarkStudy):
7567
"""
7668
This class is an alternative to SorterStudy when the dataset do not have groundtruth.
@@ -87,19 +79,21 @@ def create_benchmark(self, key):
8779
sorter_folder = self.folder / "sorters" / self.key_to_str(key)
8880
benchmark = SorterBenchmarkWithoutGroundTruth(recording, gt_sorting, params, sorter_folder)
8981
return benchmark
90-
82+
9183
def _get_comparison_groups(self):
9284
# multicomparison are done on all cases sharing the same dataset key.
9385
case_keys = list(self.cases.keys())
9486
groups = {}
9587
for case_key in case_keys:
96-
data_key = self.cases[case_key]['dataset']
88+
data_key = self.cases[case_key]["dataset"]
9789
if data_key not in groups:
9890
groups[data_key] = []
9991
groups[data_key].append(case_key)
10092
return groups
10193

102-
def compute_results(self, case_keys=None, verbose=False, delta_time=0.4, match_score=0.5, chance_score=0.1, **result_params):
94+
def compute_results(
95+
self, case_keys=None, verbose=False, delta_time=0.4, match_score=0.5, chance_score=0.1, **result_params
96+
):
10397
# Here we need a hack because the results is not computed case by case but all at once
10498

10599
assert case_keys is None, "SorterStudyWithoutGroundTruth do not permit compute_results for sub cases"
@@ -115,7 +109,7 @@ def compute_results(self, case_keys=None, verbose=False, delta_time=0.4, match_s
115109

116110
for data_key, group in groups.items():
117111

118-
sorting_list = [self.get_result(key)['sorting'] for key in group]
112+
sorting_list = [self.get_result(key)["sorting"] for key in group]
119113
name_list = [key for key in group]
120114
multi_comp = compare_multiple_sorters(
121115
sorting_list,
@@ -132,7 +126,7 @@ def compute_results(self, case_keys=None, verbose=False, delta_time=0.4, match_s
132126
# and then the same multi comp is stored for each case_key
133127
for key in case_keys:
134128
benchmark = self.benchmarks[key]
135-
benchmark.result['multi_comp'] = multi_comp
129+
benchmark.result["multi_comp"] = multi_comp
136130
benchmark.save_result(self.folder / "results" / self.key_to_str(key))
137131

138132
def plot_residual_peak_amplitudes(self, figsize=None):
@@ -143,7 +137,7 @@ def plot_residual_peak_amplitudes(self, figsize=None):
143137

144138
for data_key, group in groups.items():
145139
fig, ax = plt.subplots(figsize=figsize)
146-
140+
147141
lim0, lim1 = np.inf, -np.inf
148142

149143
for key in group:
@@ -158,7 +152,6 @@ def plot_residual_peak_amplitudes(self, figsize=None):
158152
if lim0 > 0:
159153
lim0 = 0
160154

161-
162155
for key in group:
163156
peaks = self.get_result(key)["peaks_from_residual"]
164157
print(peaks.size)
@@ -167,7 +160,7 @@ def plot_residual_peak_amplitudes(self, figsize=None):
167160
ax.plot(bins[:-1], count, color=colors[key], label=self.cases[key]["label"])
168161

169162
ax.legend()
170-
163+
171164
# def plot_quality_metrics_comparison_on_agreement(self, qm_name='rp_contamination', figsize=None):
172165
# import matplotlib.pyplot as plt
173166

@@ -189,13 +182,13 @@ def plot_residual_peak_amplitudes(self, figsize=None):
189182

190183
# multi_comp = self.get_result(key1)['multi_comp']
191184
# comp = multi_comp.comparisons[key1, key2]
192-
185+
193186
# match_12 = comp.hungarian_match_12
194187
# if match_12.dtype.kind =='i':
195188
# mask = match_12.values != -1
196189
# if match_12.dtype.kind =='U':
197190
# mask = match_12.values != ''
198-
191+
199192
# common_unit1_ids = match_12[mask].index
200193
# common_unit2_ids = match_12[mask].values
201194
# metrics1 = self.get_result(key1)["sorter_analyzer"].get_extension("quality_metrics").get_data()
@@ -215,5 +208,3 @@ def plot_residual_peak_amplitudes(self, figsize=None):
215208
# ax.set_yticks([])
216209
# ax.set_xticklabels([])
217210
# ax.set_yticklabels([])
218-
219-

src/spikeinterface/benchmark/residual_analysis.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
from spikeinterface.core.generate import InjectTemplatesRecording
22

33

4-
54
def analyse_residual(
6-
analyzer,
7-
detect_peaks_kwargs=dict(
8-
method="locally_exclusive",
9-
peak_sign="both",
10-
detect_threshold=6.,
11-
12-
),
13-
**job_kwargs
14-
):
5+
analyzer,
6+
detect_peaks_kwargs=dict(
7+
method="locally_exclusive",
8+
peak_sign="both",
9+
detect_threshold=6.0,
10+
),
11+
**job_kwargs,
12+
):
1513
"""
1614
This create the residual by removing each spike from the recording.
1715
This take in account the spike amplitude scaling, analyzer need "amplitude_scalings" extensions.
@@ -41,8 +39,6 @@ def analyse_residual(
4139
return residual, peaks
4240

4341

44-
45-
4642
def make_residual_recording(analyzer):
4743
"""
4844
This make a lazy recording residual from an anlyzer.
@@ -56,7 +52,7 @@ def make_residual_recording(analyzer):
5652
residual : Recording
5753
The resdiual
5854
"""
59-
55+
6056
templates = analyzer.get_extension("templates").get_templates(outputs="Templates")
6157
neg_templates_array = templates.templates_array.copy()
6258
neg_templates_array *= -1

src/spikeinterface/benchmark/tests/test_benchmark_sorter_without_gt.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,12 @@ def _create_simple_study_no_gt(study_folder):
4040
"sorter_name": "spykingcircus2",
4141
},
4242
},
43-
4443
}
4544

46-
study = SorterStudyWithoutGroundTruth.create(
47-
study_folder, datasets=datasets, cases=cases
48-
)
45+
study = SorterStudyWithoutGroundTruth.create(study_folder, datasets=datasets, cases=cases)
4946
# print(study)
5047

48+
5149
@pytest.mark.skip()
5250
def test_SorterStudyWithoutGroundTruth(create_simple_study):
5351
# job_kwargs = dict(n_jobs=2, chunk_duration="1s")
@@ -64,9 +62,10 @@ def test_SorterStudyWithoutGroundTruth(create_simple_study):
6462
print(study)
6563

6664

67-
6865
if __name__ == "__main__":
69-
study_folder_simple = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudyWithoutGroundTruth"
66+
study_folder_simple = (
67+
Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudyWithoutGroundTruth"
68+
)
7069
print(study_folder_simple)
7170
if study_folder_simple.exists():
7271
shutil.rmtree(study_folder_simple)

src/spikeinterface/benchmark/tests/test_residual_analysis.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010

1111
job_kwargs = dict(n_jobs=-1, progress_bar=True)
1212

13+
1314
@pytest.mark.skip()
1415
def test_analyse_residual():
15-
_, _ , analyzer = make_dataset()
16+
_, _, analyzer = make_dataset()
1617
if not analyzer.has_extension("amplitude_scalings"):
1718
analyzer.compute("amplitude_scalings", **job_kwargs)
1819
print(analyzer)

src/spikeinterface/comparison/comparisontools.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -461,19 +461,19 @@ def make_possible_match(agreement_scores, min_score):
461461
def _empty_match_series(unit1_ids, unit2_ids):
462462
# construct empty series of match with the correct dtype for best match and hugarian match
463463
import pandas as pd
464+
464465
match_12 = pd.Series(data=np.zeros(unit1_ids.size, dtype=unit2_ids.dtype), index=unit1_ids)
465-
if unit2_ids.dtype.kind == 'i':
466+
if unit2_ids.dtype.kind == "i":
466467
match_12[:] = -1
467-
elif unit2_ids.dtype.kind == 'U':
468-
match_12[:] = ''
469-
elif unit2_ids.dtype.kind == 'O':
470-
match_12[:] = ''
468+
elif unit2_ids.dtype.kind == "U":
469+
match_12[:] = ""
470+
elif unit2_ids.dtype.kind == "O":
471+
match_12[:] = ""
471472
else:
472473
raise ValueError("make_best_match or make_hungarian_match has unit_ids dtype wich are not 'i' or 'U'")
473474
return match_12
474475

475476

476-
477477
def make_best_match(agreement_scores, min_score) -> "tuple[pd.Series, pd.Series]":
478478
"""
479479
Given an agreement matrix and a min_score threshold.

src/spikeinterface/comparison/paircomparisons.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def get_matching_unit_list2(self, unit2):
203203
return self.possible_match_21[unit2]
204204

205205
def get_agreement_fraction(self, unit1=None, unit2=None):
206-
if unit1 is None or unit1 == -1 or unit1 == '' or unit2 is None or unit2 == -1 or unit2 == '':
206+
if unit1 is None or unit1 == -1 or unit1 == "" or unit2 is None or unit2 == -1 or unit2 == "":
207207
return 0
208208
else:
209209
return self.agreement_scores.at[unit1, unit2]
@@ -571,7 +571,7 @@ def get_false_positive_units(self, redundant_score=None):
571571
false_positive_ids = []
572572
for u2 in self.unit2_ids:
573573
if u2 not in matched_units2:
574-
if self.best_match_21[u2] == -1 or self.best_match_21[u2] == '':
574+
if self.best_match_21[u2] == -1 or self.best_match_21[u2] == "":
575575
false_positive_ids.append(u2)
576576
else:
577577
u1 = self.best_match_21[u2]
@@ -615,7 +615,7 @@ def get_redundant_units(self, redundant_score=None):
615615
matched_units2 = list(self.hungarian_match_12.values)
616616
redundant_ids = []
617617
for u2 in self.unit2_ids:
618-
if u2 not in matched_units2 and self.best_match_21[u2] != -1 and self.best_match_21[u2] != '':
618+
if u2 not in matched_units2 and self.best_match_21[u2] != -1 and self.best_match_21[u2] != "":
619619
u1 = self.best_match_21[u2]
620620
if u2 != self.best_match_12[u1]:
621621
score = self.agreement_scores.at[u1, u2]

0 commit comments

Comments
 (0)