Skip to content

Commit c398228

Browse files
authored
Merge pull request #3923 from samuelgarcia/components_merge_templates
Update TDC2 and SC2 with various improvement in sorting components (clustering + matching)
2 parents f2d6373 + 1af1ebb commit c398228

39 files changed

Lines changed: 1210 additions & 708 deletions

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,15 @@ test = [
161161
# streaming templates
162162
"s3fs",
163163

164-
# tridesclous
164+
# tridesclous2
165165
"numba<0.61.0;python_version<'3.13'",
166166
"numba>=0.61.0;python_version>='3.13'",
167167
"hdbscan>=0.8.33", # Previous version had a broken wheel
168168

169+
# isosplit is needed for tridesclous2 normally, but isosplit is only built until python3.11
170+
# so let's wait for a new build of isosplit6
171+
# "isosplit6",
172+
169173
# for sortingview backend
170174
"sortingview>=0.12.0",
171175

src/spikeinterface/benchmark/benchmark_base.py

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import time
1010

1111

12+
from spikeinterface.core import SortingAnalyzer, ChannelSparsity, NumpySorting
1213
from spikeinterface.core import SortingAnalyzer, NumpySorting
1314
from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
1415
from spikeinterface import load, create_sorting_analyzer, load_sorting_analyzer
@@ -46,6 +47,7 @@ def __init__(self, study_folder):
4647
self.levels = None
4748
self.colors_by_case = None
4849
self.colors_by_levels = {}
50+
self.labels_by_levels = {}
4951
self.scan_folder()
5052

5153
@classmethod
@@ -120,8 +122,41 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
120122
rec, gt_sorting = data
121123

122124
if gt_sorting is not None:
125+
if "gt_unit_locations" in gt_sorting.get_property_keys():
126+
# if real units locations is present then use it for a better sparsity
127+
# then the real max channel is used
128+
radius_um = 100.0
129+
channel_ids = rec.channel_ids
130+
unit_ids = gt_sorting.unit_ids
131+
gt_unit_locations = gt_sorting.get_property("gt_unit_locations")
132+
channel_locations = rec.get_channel_locations()
133+
max_channel_indices = np.argmin(
134+
np.linalg.norm(
135+
gt_unit_locations[:, np.newaxis, :2] - channel_locations[np.newaxis, :], axis=2
136+
),
137+
axis=1,
138+
)
139+
mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool")
140+
distances = np.linalg.norm(
141+
channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2
142+
)
143+
for unit_ind, unit_id in enumerate(unit_ids):
144+
chan_ind = max_channel_indices[unit_ind]
145+
(chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um)
146+
mask[unit_ind, chan_inds] = True
147+
sparsity = ChannelSparsity(mask, unit_ids, channel_ids)
148+
sparse = False
149+
else:
150+
sparse = True
151+
sparsity = None
152+
123153
analyzer = create_sorting_analyzer(
124-
gt_sorting, rec, sparse=True, format="binary_folder", folder=local_analyzer_folder
154+
gt_sorting,
155+
rec,
156+
sparse=sparse,
157+
sparsity=sparsity,
158+
format="binary_folder",
159+
folder=local_analyzer_folder,
125160
)
126161
analyzer.compute("random_spikes")
127162
analyzer.compute("templates")
@@ -135,6 +170,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
135170
analyzer = create_sorting_analyzer(
136171
gt_sorting, rec, sparse=False, format="binary_folder", folder=local_analyzer_folder
137172
)
173+
138174
else:
139175
# new case : analyzer
140176
assert isinstance(data, SortingAnalyzer)
@@ -254,6 +290,8 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs):
254290

255291
for key in job_keys:
256292
benchmark = self.create_benchmark(key)
293+
if verbose:
294+
print("### Run benchmark", key, "###")
257295
t0 = time.perf_counter()
258296
benchmark.run(**job_kwargs)
259297
t1 = time.perf_counter()
@@ -324,14 +362,16 @@ def get_run_times(self, case_keys=None):
324362
df.index.names = self.levels
325363
return df
326364

327-
def get_grouped_keys_mapping(self, levels_to_group_by=None):
365+
def get_grouped_keys_mapping(self, levels_to_group_by=None, case_keys=None):
328366
"""
329367
Return a dictionary of grouped keys.
330368
331369
Parameters
332370
----------
333371
levels_to_group_by : list
334372
A list of levels to group by.
373+
case_keys : list
374+
Optionally a sub-list of case_keys to consider
335375
336376
Returns
337377
-------
@@ -341,18 +381,19 @@ def get_grouped_keys_mapping(self, levels_to_group_by=None):
341381
labels : dict
342382
A dictionary of labels, with the new keys as keys and the labels as values.
343383
"""
344-
cases = list(self.cases.keys())
384+
if case_keys is None:
385+
case_keys = list(self.cases.keys())
345386
if levels_to_group_by is None or self.levels is None:
346-
keys_mapping = {key: [key] for key in cases}
387+
keys_mapping = {key: [key] for key in case_keys}
347388
elif len(self.levels) == 1:
348-
keys_mapping = {key: [key] for key in cases}
389+
keys_mapping = {key: [key] for key in case_keys}
349390
else:
350391
study_levels = self.levels
351392
assert np.all(
352393
[l in study_levels for l in levels_to_group_by]
353394
), f"levels_to_group_by must be in {study_levels}, got {levels_to_group_by}"
354395
keys_mapping = {}
355-
for key in cases:
396+
for key in case_keys:
356397
new_key = tuple(key[list(study_levels).index(level)] for level in levels_to_group_by)
357398
if len(new_key) == 1:
358399
new_key = new_key[0]
@@ -361,13 +402,17 @@ def get_grouped_keys_mapping(self, levels_to_group_by=None):
361402
keys_mapping[new_key].append(key)
362403

363404
if levels_to_group_by is None:
364-
labels = {key: self.cases[key]["label"] for key in cases}
405+
labels = {key: self.cases[key]["label"] for key in case_keys}
365406
else:
366-
key0 = list(keys_mapping.keys())[0]
367-
if isinstance(key0, tuple):
368-
labels = {key: "-".join(key) for key in keys_mapping}
407+
level_key = tuple(levels_to_group_by) if len(levels_to_group_by) > 1 else levels_to_group_by[0]
408+
if level_key in self.labels_by_levels:
409+
labels = self.labels_by_levels[level_key]
369410
else:
370-
labels = {key: key for key in keys_mapping}
411+
key0 = list(keys_mapping.keys())[0]
412+
if isinstance(key0, tuple):
413+
labels = {key: "-".join(key) for key in keys_mapping}
414+
else:
415+
labels = {key: key for key in keys_mapping}
371416

372417
return keys_mapping, labels
373418

@@ -383,6 +428,8 @@ def compute_results(self, case_keys=None, verbose=False, **result_params):
383428

384429
job_keys = []
385430
for key in case_keys:
431+
if verbose:
432+
print("### Compute result", key, "###")
386433
benchmark = self.benchmarks[key]
387434
assert benchmark is not None
388435
benchmark.compute_result(**result_params)
@@ -438,9 +485,9 @@ def get_gt_unit_locations(self, case_key):
438485
unit_locations_ext = sorting_analyzer.get_extension("unit_locations")
439486
return unit_locations_ext.get_data()
440487

441-
def get_templates(self, key, operator="average"):
488+
def get_templates(self, key, operator="average", outputs="numpy"):
442489
sorting_analyzer = self.get_sorting_analyzer(case_key=key)
443-
templates = sorting_analyzer.get_extenson("templates").get_data(operator=operator)
490+
templates = sorting_analyzer.get_extension("templates").get_data(operator=operator, outputs=outputs)
444491
return templates
445492

446493
def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], force=False, **job_kwargs):
@@ -668,15 +715,15 @@ def get_count_units(self, case_keys=None, well_detected_score=None, redundant_sc
668715
gt_sorting = comp.sorting1
669716
sorting = comp.sorting2
670717

671-
count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids())
672-
count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids())
673-
count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score)
718+
count_units.at[key, "num_gt"] = len(gt_sorting.get_unit_ids())
719+
count_units.at[key, "num_sorter"] = len(sorting.get_unit_ids())
720+
count_units.at[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score)
674721

675722
if comp.exhaustive_gt:
676-
count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score)
677-
count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score)
678-
count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score)
679-
count_units.loc[key, "num_bad"] = comp.count_bad_units()
723+
count_units.at[key, "num_redundant"] = comp.count_redundant_units(redundant_score)
724+
count_units.at[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score)
725+
count_units.at[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score)
726+
count_units.at[key, "num_bad"] = comp.count_bad_units()
680727

681728
return count_units
682729

src/spikeinterface/benchmark/benchmark_clustering.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def run(self, **job_kwargs):
3535
)
3636
self.result["peak_labels"] = peak_labels
3737

38-
def compute_result(self, **result_params):
39-
result_params, job_kwargs = split_job_kwargs(result_params)
38+
def compute_result(self, with_template=False, **job_kwargs):
39+
# result_params, job_kwargs = split_job_kwargs(result_params)
4040
job_kwargs = fix_job_kwargs(job_kwargs)
4141
self.noise = self.result["peak_labels"] < 0
4242
spikes = self.gt_sorting.to_spike_vector()
@@ -68,19 +68,20 @@ def compute_result(self, **result_params):
6868
self.result["sliced_gt_sorting"], self.result["clustering"], exhaustive_gt=self.exhaustive_gt
6969
)
7070

71-
sorting_analyzer = create_sorting_analyzer(
72-
self.result["sliced_gt_sorting"], self.recording, format="memory", sparse=False, **job_kwargs
73-
)
74-
sorting_analyzer.compute("random_spikes")
75-
ext = sorting_analyzer.compute("templates", **job_kwargs)
76-
self.result["sliced_gt_templates"] = ext.get_data(outputs="Templates")
71+
if with_template:
72+
sorting_analyzer = create_sorting_analyzer(
73+
self.result["sliced_gt_sorting"], self.recording, format="memory", sparse=False, **job_kwargs
74+
)
75+
sorting_analyzer.compute("random_spikes")
76+
ext = sorting_analyzer.compute("templates", **job_kwargs)
77+
self.result["sliced_gt_templates"] = ext.get_data(outputs="Templates")
7778

78-
sorting_analyzer = create_sorting_analyzer(
79-
self.result["clustering"], self.recording, format="memory", sparse=False, **job_kwargs
80-
)
81-
sorting_analyzer.compute("random_spikes")
82-
ext = sorting_analyzer.compute("templates", **job_kwargs)
83-
self.result["clustering_templates"] = ext.get_data(outputs="Templates")
79+
sorting_analyzer = create_sorting_analyzer(
80+
self.result["clustering"], self.recording, format="memory", sparse=False, **job_kwargs
81+
)
82+
sorting_analyzer.compute("random_spikes")
83+
ext = sorting_analyzer.compute("templates", **job_kwargs)
84+
self.result["clustering_templates"] = ext.get_data(outputs="Templates")
8485

8586
_run_key_saved = [("peak_labels", "npy")]
8687

src/spikeinterface/benchmark/benchmark_matching.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,24 +91,35 @@ def plot_performances_ordered(self, *args, **kwargs):
9191

9292
return plot_performances_ordered(self, *args, **kwargs)
9393

94-
def plot_collisions(self, case_keys=None, figsize=None):
94+
def plot_collisions(self, case_keys=None, metric="l2", mode="lines", show_legend=True, axs=None, figsize=None):
9595
if case_keys is None:
9696
case_keys = list(self.cases.keys())
9797
import matplotlib.pyplot as plt
9898

99-
fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
99+
if axs is None:
100+
fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
101+
axs = axs[0, :]
102+
else:
103+
fig = axs[0].figure
100104

101105
for count, key in enumerate(case_keys):
102-
templates_array = self.get_result(key)["templates"].templates_array
106+
label = self.cases[key]["label"]
107+
templates_array = self.get_sorting_analyzer(key).get_extension("templates").get_templates(outputs="numpy")
108+
ax = axs[count]
103109
plot_comparison_collision_by_similarity(
104110
self.get_result(key)["gt_collision"],
105111
templates_array,
106-
ax=axs[0, count],
107-
show_legend=True,
108-
mode="lines",
109-
good_only=False,
112+
metric=metric,
113+
ax=ax,
114+
show_legend=show_legend,
115+
mode=mode,
116+
# good_only=False,
117+
# good_only=False,
118+
good_only=True,
110119
)
111120

121+
ax.set_title(label)
122+
112123
return fig
113124

114125
def plot_unit_counts(self, case_keys=None, **kwargs):

0 commit comments

Comments
 (0)