@@ -39,7 +39,8 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
3939 "motion_correction" : {"preset" : "dredge_fast" },
4040 "merging" : {"max_distance_um" : 50 },
4141 "clustering" : {"method" : "iterative-hdbscan" , "method_kwargs" : dict ()},
42- "cleaning" : {"min_snr" : 5 , "max_jitter_ms" : 0.1 , "sparsify_threshold" : None },
42+ "cleaning" : {"min_snr" : 5 , "max_jitter_ms" : 0.2 , "sparsify_threshold" : 1 , "mean_sd_ratio_threshold" : 3 },
43+ "min_firing_rate" : 0.1 ,
4344 "matching" : {"method" : "circus-omp" , "method_kwargs" : dict (), "pipeline_kwargs" : dict ()},
4445 "apply_preprocessing" : True ,
4546 "apply_whitening" : True ,
@@ -103,6 +104,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
103104 from spikeinterface .sortingcomponents .peak_detection import detect_peaks
104105 from spikeinterface .sortingcomponents .peak_selection import select_peaks
105106 from spikeinterface .sortingcomponents .clustering import find_clusters_from_peaks
107+ from spikeinterface .sortingcomponents .clustering .tools import remove_small_cluster
106108 from spikeinterface .sortingcomponents .matching import find_spikes_from_templates
107109 from spikeinterface .sortingcomponents .tools import check_probe_for_drift_correction
108110 from spikeinterface .sortingcomponents .tools import clean_templates
@@ -118,8 +120,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
118120 ms_before = params ["general" ].get ("ms_before" , 0.5 )
119121 ms_after = params ["general" ].get ("ms_after" , 1.5 )
120122 radius_um = params ["general" ].get ("radius_um" , 100.0 )
121- detect_threshold = params ["detection" ]["method_kwargs" ].get ("detect_threshold" , 5 )
122- peak_sign = params ["detection" ].get ("peak_sign" , "neg" )
123123 deterministic = params ["deterministic_peaks_detection" ]
124124 debug = params ["debug" ]
125125 seed = params ["seed" ]
@@ -310,6 +310,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
310310 if verbose :
311311 print ("Kept %d peaks for clustering" % len (selected_peaks ))
312312
313+ cleaning_kwargs = params .get ("cleaning" , {}).copy ()
314+ cleaning_kwargs ["remove_empty" ] = True
315+
313316 if clustering_method in [
314317 "iterative-hdbscan" ,
315318 "iterative-isosplit" ,
@@ -319,6 +322,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
319322 clustering_params .update (verbose = verbose )
320323 clustering_params .update (seed = seed )
321324 clustering_params .update (peaks_svd = params ["general" ])
325+ if clustering_method in ["iterative-hdbscan" , "iterative-isosplit" ]:
326+ clustering_params .update (clean_templates = cleaning_kwargs )
327+ clustering_params ["noise_levels" ] = noise_levels
328+
322329 if debug :
323330 clustering_params ["debug_folder" ] = sorter_output_folder / "clustering"
324331
@@ -328,6 +335,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
328335 method = clustering_method ,
329336 method_kwargs = clustering_params ,
330337 extra_outputs = True ,
338+ verbose = verbose ,
331339 job_kwargs = job_kwargs ,
332340 )
333341
@@ -365,7 +373,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
365373 else :
366374 from spikeinterface .sortingcomponents .clustering .tools import get_templates_from_peaks_and_svd
367375
368- dense_templates , new_sparse_mask = get_templates_from_peaks_and_svd (
376+ dense_templates , new_sparse_mask , max_std_per_channel = get_templates_from_peaks_and_svd (
369377 recording_w ,
370378 selected_peaks ,
371379 peak_labels ,
@@ -375,16 +383,30 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
375383 more_outs ["peaks_svd" ],
376384 more_outs ["peak_svd_sparse_mask" ],
377385 operator = "median" ,
386+ return_max_std_per_channel = True ,
378387 )
379388 # this release the peak_svd memmap file
380389 templates = dense_templates .to_sparse (new_sparse_mask )
381390
382391 del more_outs
383392
384- cleaning_kwargs = params .get ("cleaning" , {}).copy ()
385- cleaning_kwargs ["noise_levels" ] = noise_levels
386- cleaning_kwargs ["remove_empty" ] = True
387- templates = clean_templates (templates , ** cleaning_kwargs )
393+ before_clean_ids = templates .unit_ids .copy ()
394+ cleaning_kwargs ["max_std_per_channel" ] = max_std_per_channel
395+ cleaning_kwargs ["verbose" ] = verbose
396+ templates = clean_templates (templates , noise_levels = noise_levels , ** cleaning_kwargs )
397+ remove_peak_mask = ~ np .isin (peak_labels , templates .unit_ids )
398+ peak_labels [remove_peak_mask ] = - 1
399+
400+ if params ["min_firing_rate" ] is not None :
401+ peak_labels , to_keep = remove_small_cluster (
402+ recording_w ,
403+ selected_peaks ,
404+ peak_labels ,
405+ min_firing_rate = params ["min_firing_rate" ],
406+ subsampling_factor = peaks .size / selected_peaks .size ,
407+ verbose = verbose ,
408+ )
409+ templates = templates .select_units (to_keep )
388410
389411 if verbose :
390412 print ("Kept %d clean clusters" % len (templates .unit_ids ))
@@ -508,5 +530,4 @@ def final_cleaning_circus(
508530 sparsity_overlap = sparsity_overlap ,
509531 ** job_kwargs ,
510532 )
511-
512533 return final_sa