Skip to content

Commit f6e0832

Browse files
ygersamuelgarcia
and authored
Final cleaning of mixtures for lupin/tdc/sc (#4244)
Co-authored-by: Samuel Garcia <sam.garcia.die@gmail.com>
1 parent 402a862 commit f6e0832

9 files changed

Lines changed: 279 additions & 222 deletions

File tree

src/spikeinterface/benchmark/benchmark_clustering.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,14 @@ def __init__(self, recording, gt_sorting, params, indices, peaks, exhaustive_gt=
2929
self.method_kwargs = params["method_kwargs"]
3030
self.result = {}
3131

32-
def run(self, **job_kwargs):
32+
def run(self, verbose=True, **job_kwargs):
3333
labels, peak_labels = find_clusters_from_peaks(
34-
self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, job_kwargs=job_kwargs
34+
self.recording,
35+
self.peaks,
36+
method=self.method,
37+
method_kwargs=self.method_kwargs,
38+
verbose=verbose,
39+
job_kwargs=job_kwargs,
3540
)
3641
self.result["peak_labels"] = peak_labels
3742

src/spikeinterface/benchmark/benchmark_matching.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,14 @@ def __init__(self, recording, gt_sorting, params):
2626
self.method_kwargs = params["method_kwargs"]
2727
self.result = {}
2828

29-
def run(self, **job_kwargs):
29+
def run(self, verbose=True, **job_kwargs):
3030
spikes = find_spikes_from_templates(
31-
self.recording, self.templates, method=self.method, method_kwargs=self.method_kwargs, job_kwargs=job_kwargs
31+
self.recording,
32+
self.templates,
33+
method=self.method,
34+
method_kwargs=self.method_kwargs,
35+
verbose=verbose,
36+
job_kwargs=job_kwargs,
3237
)
3338
unit_ids = self.templates.unit_ids
3439
sorting = np.zeros(spikes.size, dtype=minimum_spike_dtype)

src/spikeinterface/sorters/internal/lupin.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,10 @@ class LupinSorter(ComponentsBasedSorter):
5454
"clustering_recursive_depth": 3,
5555
"ms_before": 1.0,
5656
"ms_after": 2.5,
57-
"sparsity_threshold": 1.5,
58-
"template_min_snr": 2.5,
57+
"template_sparsify_threshold": 1.5,
58+
"template_min_snr_ptp": 4.0,
59+
"template_max_jitter_ms": 0.2,
60+
"min_firing_rate": 0.1,
5961
"gather_mode": "memory",
6062
"job_kwargs": {},
6163
"seed": None,
@@ -80,8 +82,10 @@ class LupinSorter(ComponentsBasedSorter):
8082
"clustering_recursive_depth": "Clustering recursivity",
8183
"ms_before": "Milliseconds before the spike peak for template matching",
8284
"ms_after": "Milliseconds after the spike peak for template matching",
83-
"sparsity_threshold": "Threshold to sparsify templates before template matching",
84-
"template_min_snr": "Threshold to remove templates before template matching",
85+
"template_sparsify_threshold": "Threshold to sparsify templates before template matching",
86+
"template_min_snr_ptp": "Threshold to remove templates before template matching",
87+
"template_max_jitter_ms": "Threshold on jitters to remove templates before template matching",
88+
"min_firing_rate": "To remove small cluster in size before template matching",
8589
"gather_mode": "How to accumulate spikes in matching: memory/npy",
8690
"job_kwargs": "The famous and fabulous job_kwargs",
8791
"seed": "Seed for random number",
@@ -232,6 +236,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
232236
clustering_kwargs["peaks_svd"]["n_components"] = params["n_svd_components_per_channel"]
233237
clustering_kwargs["split"]["recursive_depth"] = params["clustering_recursive_depth"]
234238
clustering_kwargs["split"]["method_kwargs"]["n_pca_features"] = params["n_pca_features"]
239+
clustering_kwargs["clean_templates"]["sparsify_threshold"] = params["template_sparsify_threshold"]
240+
clustering_kwargs["clean_templates"]["min_snr"] = params["template_min_snr_ptp"]
241+
clustering_kwargs["clean_templates"]["max_jitter_ms"] = params["template_max_jitter_ms"]
242+
clustering_kwargs["noise_levels"] = noise_levels
243+
clustering_kwargs["clean_low_firing"]["min_firing_rate"] = params["min_firing_rate"]
244+
clustering_kwargs["clean_low_firing"]["subsampling_factor"] = all_peaks.size / peaks.size
235245

236246
if params["debug"]:
237247
clustering_kwargs["debug_folder"] = sorter_output_folder
@@ -290,10 +300,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
290300
# this sparsifies more
291301
templates = clean_templates(
292302
templates,
293-
sparsify_threshold=params["sparsity_threshold"],
303+
sparsify_threshold=params["template_sparsify_threshold"],
294304
noise_levels=noise_levels,
295-
min_snr=params["template_min_snr"],
296-
max_jitter_ms=None,
305+
min_snr=params["template_min_snr_ptp"],
306+
max_jitter_ms=params["template_max_jitter_ms"],
297307
remove_empty=True,
298308
)
299309

src/spikeinterface/sorters/internal/spyking_circus2.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
3939
"motion_correction": {"preset": "dredge_fast"},
4040
"merging": {"max_distance_um": 50},
4141
"clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()},
42-
"cleaning": {"min_snr": 5, "max_jitter_ms": 0.1, "sparsify_threshold": None},
42+
"cleaning": {"min_snr": 5, "max_jitter_ms": 0.2, "sparsify_threshold": 1, "mean_sd_ratio_threshold": 3},
43+
"min_firing_rate": 0.1,
4344
"matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()},
4445
"apply_preprocessing": True,
4546
"apply_whitening": True,
@@ -103,6 +104,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
103104
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
104105
from spikeinterface.sortingcomponents.peak_selection import select_peaks
105106
from spikeinterface.sortingcomponents.clustering import find_clusters_from_peaks
107+
from spikeinterface.sortingcomponents.clustering.tools import remove_small_cluster
106108
from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
107109
from spikeinterface.sortingcomponents.tools import check_probe_for_drift_correction
108110
from spikeinterface.sortingcomponents.tools import clean_templates
@@ -118,8 +120,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
118120
ms_before = params["general"].get("ms_before", 0.5)
119121
ms_after = params["general"].get("ms_after", 1.5)
120122
radius_um = params["general"].get("radius_um", 100.0)
121-
detect_threshold = params["detection"]["method_kwargs"].get("detect_threshold", 5)
122-
peak_sign = params["detection"].get("peak_sign", "neg")
123123
deterministic = params["deterministic_peaks_detection"]
124124
debug = params["debug"]
125125
seed = params["seed"]
@@ -310,6 +310,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
310310
if verbose:
311311
print("Kept %d peaks for clustering" % len(selected_peaks))
312312

313+
cleaning_kwargs = params.get("cleaning", {}).copy()
314+
cleaning_kwargs["remove_empty"] = True
315+
313316
if clustering_method in [
314317
"iterative-hdbscan",
315318
"iterative-isosplit",
@@ -319,6 +322,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
319322
clustering_params.update(verbose=verbose)
320323
clustering_params.update(seed=seed)
321324
clustering_params.update(peaks_svd=params["general"])
325+
if clustering_method in ["iterative-hdbscan", "iterative-isosplit"]:
326+
clustering_params.update(clean_templates=cleaning_kwargs)
327+
clustering_params["noise_levels"] = noise_levels
328+
322329
if debug:
323330
clustering_params["debug_folder"] = sorter_output_folder / "clustering"
324331

@@ -328,6 +335,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
328335
method=clustering_method,
329336
method_kwargs=clustering_params,
330337
extra_outputs=True,
338+
verbose=verbose,
331339
job_kwargs=job_kwargs,
332340
)
333341

@@ -365,7 +373,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
365373
else:
366374
from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd
367375

368-
dense_templates, new_sparse_mask = get_templates_from_peaks_and_svd(
376+
dense_templates, new_sparse_mask, max_std_per_channel = get_templates_from_peaks_and_svd(
369377
recording_w,
370378
selected_peaks,
371379
peak_labels,
@@ -375,16 +383,30 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
375383
more_outs["peaks_svd"],
376384
more_outs["peak_svd_sparse_mask"],
377385
operator="median",
386+
return_max_std_per_channel=True,
378387
)
379388
# this release the peak_svd memmap file
380389
templates = dense_templates.to_sparse(new_sparse_mask)
381390

382391
del more_outs
383392

384-
cleaning_kwargs = params.get("cleaning", {}).copy()
385-
cleaning_kwargs["noise_levels"] = noise_levels
386-
cleaning_kwargs["remove_empty"] = True
387-
templates = clean_templates(templates, **cleaning_kwargs)
393+
before_clean_ids = templates.unit_ids.copy()
394+
cleaning_kwargs["max_std_per_channel"] = max_std_per_channel
395+
cleaning_kwargs["verbose"] = verbose
396+
templates = clean_templates(templates, noise_levels=noise_levels, **cleaning_kwargs)
397+
remove_peak_mask = ~np.isin(peak_labels, templates.unit_ids)
398+
peak_labels[remove_peak_mask] = -1
399+
400+
if params["min_firing_rate"] is not None:
401+
peak_labels, to_keep = remove_small_cluster(
402+
recording_w,
403+
selected_peaks,
404+
peak_labels,
405+
min_firing_rate=params["min_firing_rate"],
406+
subsampling_factor=peaks.size / selected_peaks.size,
407+
verbose=verbose,
408+
)
409+
templates = templates.select_units(to_keep)
388410

389411
if verbose:
390412
print("Kept %d clean clusters" % len(templates.unit_ids))
@@ -508,5 +530,4 @@ def final_cleaning_circus(
508530
sparsity_overlap=sparsity_overlap,
509531
**job_kwargs,
510532
)
511-
512533
return final_sa

src/spikeinterface/sorters/internal/tridesclous2.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,15 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
5151
"clustering": {
5252
"recursive_depth": 3,
5353
},
54+
"min_firing_rate": 0.1,
5455
"templates": {
5556
"ms_before": 2.0,
5657
"ms_after": 3.0,
5758
"max_spikes_per_unit": 400,
5859
"sparsity_threshold": 1.5,
59-
"min_snr": 2.5,
60+
"min_snr": 3.5,
6061
"radius_um": 100.0,
62+
"max_jitter_ms": 0.2,
6163
},
6264
"matching": {"method": "tdc-peeler", "method_kwargs": {}, "gather_mode": "memory"},
6365
"job_kwargs": {},
@@ -93,7 +95,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
9395
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
9496
from spikeinterface.sortingcomponents.peak_selection import select_peaks
9597
from spikeinterface.sortingcomponents.clustering.main import find_clusters_from_peaks, clustering_methods
96-
from spikeinterface.sortingcomponents.tools import remove_empty_templates
9798
from spikeinterface.preprocessing import correct_motion
9899
from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording
99100
from spikeinterface.sortingcomponents.tools import clean_templates, compute_sparsity_from_peaks_and_label
@@ -194,6 +195,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
194195
clustering_kwargs["split"].update(params["clustering"])
195196
if params["debug"]:
196197
clustering_kwargs["debug_folder"] = sorter_output_folder
198+
clustering_kwargs["noise_levels"] = noise_levels
199+
clustering_kwargs["clean_low_firing"]["min_firing_rate"] = params["min_firing_rate"]
200+
clustering_kwargs["clean_low_firing"]["subsampling_factor"] = all_peaks.size / peaks.size
197201

198202
# if clustering_kwargs["clustering"]["clusterer"] == "isosplit6":
199203
# have_sisosplit6 = importlib.util.find_spec("isosplit6") is not None
@@ -262,13 +266,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
262266
is_in_uV=False,
263267
)
264268

265-
# this spasify more
269+
# this cleans and sparsifies more
266270
templates = clean_templates(
267271
templates,
268272
sparsify_threshold=params["templates"]["sparsity_threshold"],
269273
noise_levels=noise_levels,
270274
min_snr=params["templates"]["min_snr"],
271-
max_jitter_ms=None,
275+
max_jitter_ms=params["templates"]["max_jitter_ms"],
272276
remove_empty=True,
273277
)
274278

src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
from spikeinterface.sortingcomponents.waveforms.peak_svd import extract_peaks_svd
1010
from spikeinterface.sortingcomponents.clustering.merging_tools import merge_peak_labels_from_templates
1111
from spikeinterface.sortingcomponents.clustering.itersplit_tools import split_clusters
12-
from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd
12+
from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd, remove_small_cluster
13+
from spikeinterface.sortingcomponents.tools import clean_templates
14+
from spikeinterface.core.recording_tools import get_noise_levels
1315

1416

1517
class IterativeHDBSCANClustering:
@@ -30,6 +32,7 @@ class IterativeHDBSCANClustering:
3032
_default_params = {
3133
"peaks_svd": {"n_components": 5, "ms_before": 0.5, "ms_after": 1.5, "radius_um": 100.0},
3234
"seed": None,
35+
"noise_levels": None,
3336
"split": {
3437
"split_radius_um": 75.0,
3538
"recursive": True,
@@ -43,8 +46,18 @@ class IterativeHDBSCANClustering:
4346
"n_pca_features": 3,
4447
},
4548
},
49+
"clean_templates": {
50+
"sparsify_threshold": 1.0,
51+
"min_snr": 2.5,
52+
"remove_empty": True,
53+
"max_jitter_ms": 0.2,
54+
},
4655
"merge_from_templates": dict(similarity_thresh=0.8, num_shifts=3, use_lags=True),
4756
"merge_from_features": None,
57+
"clean_low_firing": {
58+
"min_firing_rate": 0.1,
59+
"subsampling_factor": None,
60+
},
4861
"debug_folder": None,
4962
"verbose": True,
5063
}
@@ -116,7 +129,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
116129
**split,
117130
)
118131

119-
templates, new_sparse_mask = get_templates_from_peaks_and_svd(
132+
templates, new_sparse_mask, max_std_per_channel = get_templates_from_peaks_and_svd(
120133
recording,
121134
peaks,
122135
peak_labels,
@@ -126,8 +139,27 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
126139
peaks_svd,
127140
sparse_mask,
128141
operator="median",
142+
return_max_std_per_channel=True,
129143
)
130144

145+
## Pre clean using templates (jitter, sparsify_threshold)
146+
templates = templates.to_sparse(new_sparse_mask)
147+
cleaning_kwargs = params["clean_templates"].copy()
148+
cleaning_kwargs["verbose"] = verbose
149+
cleaning_kwargs["max_std_per_channel"] = max_std_per_channel
150+
if params["noise_levels"] is not None:
151+
noise_levels = params["noise_levels"]
152+
else:
153+
noise_levels = get_noise_levels(recording, return_in_uV=False, **job_kwargs)
154+
cleaning_kwargs["noise_levels"] = noise_levels
155+
cleaned_templates = clean_templates(templates, **cleaning_kwargs)
156+
mask_keep_ids = np.isin(templates.unit_ids, cleaned_templates.unit_ids)
157+
to_remove_ids = templates.unit_ids[~mask_keep_ids]
158+
to_remove_label_mask = np.isin(peak_labels, to_remove_ids)
159+
peak_labels[to_remove_label_mask] = -1
160+
templates = cleaned_templates
161+
new_sparse_mask = templates.sparsity.mask.copy()
162+
templates = templates.to_dense()
131163
labels = templates.unit_ids
132164

133165
if verbose:
@@ -154,6 +186,21 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
154186
is_in_uV=False,
155187
)
156188

189+
# clean very small cluster before peeler
190+
if (
191+
params["clean_low_firing"]["subsampling_factor"] is not None
192+
and params["clean_low_firing"]["min_firing_rate"] is not None
193+
):
194+
peak_labels, to_keep = remove_small_cluster(
195+
recording,
196+
peaks,
197+
peak_labels,
198+
min_firing_rate=params["clean_low_firing"]["min_firing_rate"],
199+
subsampling_factor=params["clean_low_firing"]["subsampling_factor"],
200+
verbose=verbose,
201+
)
202+
templates = templates.select_units(to_keep)
203+
157204
labels = templates.unit_ids
158205

159206
if debug_folder is not None:

0 commit comments

Comments
 (0)