Skip to content

Commit 4c74630

Browse files
ygersamuelgarciapre-commit-ci[bot]
authored
Add the option for clustering methods to return temporal shifts for peaks while clustering (#4401)
Co-authored-by: Samuel Garcia <sam.garcia.die@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent af2718d commit 4c74630

6 files changed

Lines changed: 55 additions & 32 deletions

File tree

src/spikeinterface/sorters/internal/lupin.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
308308
job_kwargs=job_kwargs,
309309
)
310310

311+
if more_outs["time_shifts"] is not None:
312+
time_shifts = more_outs["time_shifts"]
313+
peaks["sample_index"] += time_shifts
314+
311315
mask = clustering_label >= 0
312316
kept_peaks = peaks[mask]
313317
kept_labels = clustering_label[mask]

src/spikeinterface/sorters/internal/tridesclous2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
348348
sorting = NumpySorting(final_spikes, sampling_frequency, templates.unit_ids)
349349

350350
auto_merge = True
351+
351352
analyzer_final = None
352353
if auto_merge:
353354
from spikeinterface.sorters.internal.spyking_circus2 import final_cleaning_circus

src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,15 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
164164
print("Kept %d raw clusters" % len(labels))
165165

166166
if params["merge_from_templates"] is not None:
167-
peak_labels, merge_template_array, new_sparse_mask, new_unit_ids = merge_peak_labels_from_templates(
168-
peaks,
169-
peak_labels,
170-
templates.unit_ids,
171-
templates.templates_array,
172-
new_sparse_mask,
173-
**params["merge_from_templates"],
167+
peak_labels, merge_template_array, new_sparse_mask, new_unit_ids, time_shifts = (
168+
merge_peak_labels_from_templates(
169+
peaks,
170+
peak_labels,
171+
templates.unit_ids,
172+
templates.templates_array,
173+
new_sparse_mask,
174+
**params["merge_from_templates"],
175+
)
174176
)
175177

176178
templates = Templates(
@@ -183,6 +185,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
183185
probe=recording.get_probe(),
184186
is_in_uV=False,
185187
)
188+
else:
189+
time_shifts = None
186190

187191
# clean very small cluster before peeler
188192
if (
@@ -210,6 +214,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
210214
more_outs = dict(
211215
svd_model=svd_model,
212216
peaks_svd=peaks_svd,
217+
time_shifts=time_shifts,
213218
peak_svd_sparse_mask=sparse_mask,
214219
)
215220
return labels, peak_labels, more_outs

src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -299,16 +299,19 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
299299
num_shifts = params_merge_from_templates["num_shifts"]
300300
num_shifts = min((num_shifts, nbefore, nafter))
301301
params_merge_from_templates["num_shifts"] = num_shifts
302-
post_merge_label2, templates_array, template_sparse_mask, unit_ids = merge_peak_labels_from_templates(
303-
peaks,
304-
post_merge_label1,
305-
unit_ids,
306-
templates_array,
307-
template_sparse_mask,
308-
**params_merge_from_templates,
302+
post_merge_label2, templates_array, template_sparse_mask, unit_ids, time_shifts = (
303+
merge_peak_labels_from_templates(
304+
peaks,
305+
post_merge_label1,
306+
unit_ids,
307+
templates_array,
308+
template_sparse_mask,
309+
**params_merge_from_templates,
310+
)
309311
)
310312
else:
311313
post_merge_label2 = post_merge_label1.copy()
314+
time_shifts = None
312315

313316
dense_templates = Templates(
314317
templates_array=templates_array,
@@ -343,7 +346,5 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
343346

344347
labels_set = templates.unit_ids
345348

346-
more_outs = dict(
347-
templates=templates,
348-
)
349+
more_outs = dict(templates=templates, time_shifts=time_shifts)
349350
return labels_set, final_peak_labels, more_outs

src/spikeinterface/sortingcomponents/clustering/merging_tools.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -556,13 +556,13 @@ def merge_peak_labels_from_templates(
556556
if not use_lags:
557557
lags = None
558558

559-
clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids = (
559+
clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids, time_shifts = (
560560
_apply_pair_mask_on_labels_and_recompute_templates(
561561
pair_mask, peak_labels, unit_ids, templates_array, template_sparse_mask, lags
562562
)
563563
)
564564

565-
return clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids
565+
return clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids, time_shifts
566566

567567

568568
def _apply_pair_mask_on_labels_and_recompute_templates(
@@ -580,7 +580,10 @@ def _apply_pair_mask_on_labels_and_recompute_templates(
580580
clean_labels = peak_labels.copy()
581581
n_components, group_labels = connected_components(pair_mask, directed=False, return_labels=True)
582582

583-
# print("merges", templates_array.shape[0], "to", n_components)
583+
if lags is not None:
584+
time_shifts = np.zeros(len(peak_labels), dtype=np.int32)
585+
else:
586+
time_shifts = None
584587

585588
merge_template_array = templates_array.copy()
586589
merge_sparsity_mask = template_sparse_mask.copy()
@@ -603,10 +606,15 @@ def _apply_pair_mask_on_labels_and_recompute_templates(
603606

604607
for i, l in enumerate(merge_group):
605608
label = unit_ids[l]
606-
weights[i] = np.sum(peak_labels == label)
609+
mask = peak_labels == label
610+
weights[i] = np.sum(mask)
607611
if i > 0:
608-
clean_labels[peak_labels == label] = unit_ids[g0]
612+
clean_labels[mask] = unit_ids[g0]
609613
keep_template[l] = False
614+
if lags is not None:
615+
shift = lags[l, g0] # which is the same as -lags[g0, l]
616+
time_shifts[mask] += shift
617+
610618
weights /= weights.sum()
611619

612620
if lags is None:
@@ -617,7 +625,7 @@ def _apply_pair_mask_on_labels_and_recompute_templates(
617625
# with shifts
618626
accumulated_template = np.zeros_like(merge_template_array[g0, :, :])
619627
for i, l in enumerate(merge_group):
620-
shift = -lags[g0, l]
628+
shift = lags[l, g0] # which is the same as -lags[g0, l]
621629
if shift > 0:
622630
# template is shifted to right
623631
temp = np.zeros_like(accumulated_template)
@@ -637,4 +645,4 @@ def _apply_pair_mask_on_labels_and_recompute_templates(
637645
merge_template_array = merge_template_array[keep_template, :, :]
638646
merge_sparsity_mask = merge_sparsity_mask[keep_template, :]
639647

640-
return clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids
648+
return clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids, time_shifts

src/spikeinterface/sortingcomponents/clustering/random_projections.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,15 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
132132
print("Kept %d raw clusters" % len(labels))
133133

134134
if params["merge_from_templates"] is not None:
135-
peak_labels, merge_template_array, new_sparse_mask, new_unit_ids = merge_peak_labels_from_templates(
136-
peaks,
137-
peak_labels,
138-
unit_ids,
139-
templates_array,
140-
np.ones((len(unit_ids), num_chans), dtype=bool),
141-
**params["merge_from_templates"],
135+
peak_labels, merge_template_array, new_sparse_mask, new_unit_ids, time_shifts = (
136+
merge_peak_labels_from_templates(
137+
peaks,
138+
peak_labels,
139+
unit_ids,
140+
templates_array,
141+
np.ones((len(unit_ids), num_chans), dtype=bool),
142+
**params["merge_from_templates"],
143+
)
142144
)
143145

144146
templates = Templates(
@@ -151,6 +153,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
151153
probe=recording.get_probe(),
152154
is_in_uV=False,
153155
)
156+
else:
157+
time_shifts = None
154158

155159
labels = templates.unit_ids
156160

@@ -160,4 +164,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
160164
if verbose:
161165
print("Kept %d non-duplicated clusters" % len(labels))
162166

163-
return labels, peak_labels, dict()
167+
return labels, peak_labels, dict(time_shifts=time_shifts, templates=templates)

0 commit comments

Comments
 (0)