1010
1111
1212def remove_redundant_units (
13- sorting_or_sorting_analyzer ,
14- align = True ,
15- unit_peak_shifts = None ,
16- delta_time = 0.4 ,
17- agreement_threshold = 0.2 ,
18- duplicate_threshold = 0.8 ,
19- remove_strategy = "minimum_shift" ,
20- peak_sign = "neg" ,
21- extra_outputs = False ,
22- ) -> BaseSorting :
13+ sorting_or_sorting_analyzer : BaseSorting | SortingAnalyzer ,
14+ align : bool = True ,
15+ unit_peak_shifts : dict [ int , float ] | None = None ,
16+ delta_time : float = 0.4 ,
17+ agreement_threshold : float = 0.2 ,
18+ duplicate_threshold : float = 0.8 ,
19+ remove_strategy : str = "minimum_shift" ,
20+ peak_sign : str = "neg" ,
21+ extra_outputs : bool = False ,
22+ ) -> BaseSorting | tuple [ BaseSorting , list [ tuple [ int , int ]]] :
2323 """
2424 Removes redundant or duplicate units by comparing the sorting output with itself.
2525
@@ -72,15 +72,14 @@ def remove_redundant_units(
7272 sorting = sorting_or_sorting_analyzer .sorting
7373 sorting_analyzer = sorting_or_sorting_analyzer
7474 else :
75- assert not align , "The 'align' option is only available when a SortingAnalyzer is used as input"
7675 # other remove strategies rely on sorting analyzer looking at templates
7776 assert remove_strategy == "max_spikes" , "For a Sorting input the remove_strategy must be 'max_spikes'"
7877 sorting = sorting_or_sorting_analyzer
7978 sorting_analyzer = None
8079
8180 if align and unit_peak_shifts is None :
8281 assert sorting_analyzer is not None , "For align=True must give a SortingAnalyzer or explicit unit_peak_shifts"
83- unit_peak_shifts = get_template_extremum_channel_peak_shift (sorting_analyzer )
82+ unit_peak_shifts = get_template_extremum_channel_peak_shift (sorting_analyzer , peak_sign = peak_sign )
8483
8584 if align :
8685 sorting_aligned = align_sorting (sorting , unit_peak_shifts )
@@ -108,25 +107,15 @@ def remove_redundant_units(
108107 remove_unit_ids .append (u1 )
109108 elif np .abs (unit_peak_shifts [u1 ]) < np .abs (unit_peak_shifts [u2 ]):
110109 remove_unit_ids .append (u2 )
111- else :
112- # equal shift use peak values
113- if np .abs (peak_values [u1 ]) < np .abs (peak_values [u2 ]):
114- remove_unit_ids .append (u1 )
115- else :
116- remove_unit_ids .append (u2 )
110+ else : # equal shift use peak values
111+ remove_unit_ids .append (u1 if peak_values [u1 ] < peak_values [u2 ] else u2 )
117112 elif remove_strategy == "highest_amplitude" :
118113 for u1 , u2 in redundant_unit_pairs :
119- if np .abs (peak_values [u1 ]) < np .abs (peak_values [u2 ]):
120- remove_unit_ids .append (u1 )
121- else :
122- remove_unit_ids .append (u2 )
114+ remove_unit_ids .append (u1 if peak_values [u1 ] < peak_values [u2 ] else u2 )
123115 elif remove_strategy == "max_spikes" :
124116 num_spikes = sorting .count_num_spikes_per_unit (outputs = "dict" )
125117 for u1 , u2 in redundant_unit_pairs :
126- if num_spikes [u1 ] < num_spikes [u2 ]:
127- remove_unit_ids .append (u1 )
128- else :
129- remove_unit_ids .append (u2 )
118+ remove_unit_ids .append (u1 if num_spikes [u1 ] < num_spikes [u2 ] else u2 )
130119 elif remove_strategy == "with_metrics" :
131120 # TODO
132121 # @aurelien @alessio
@@ -144,7 +133,9 @@ def remove_redundant_units(
144133 return sorting_clean
145134
146135
147- def find_redundant_units (sorting , delta_time : float = 0.4 , agreement_threshold = 0.2 , duplicate_threshold = 0.8 ):
136+ def find_redundant_units (
137+ sorting : BaseSorting , delta_time : float = 0.4 , agreement_threshold : float = 0.2 , duplicate_threshold : float = 0.8
138+ ) -> list [tuple [int , int ]]:
148139 """
149140 Finds redundant or duplicate units by comparing the sorting output with itself.
150141
@@ -162,9 +153,7 @@ def find_redundant_units(sorting, delta_time: float = 0.4, agreement_threshold=0
162153
163154 Returns
164155 -------
165- list
166- The list of duplicate units
167- list of 2-element lists
156+ list of 2-element tuples
168157 The list of duplicate pairs
169158 """
170159 from spikeinterface .comparison import compare_two_sorters
@@ -186,6 +175,6 @@ def find_redundant_units(sorting, delta_time: float = 0.4, agreement_threshold=0
186175 event_counts = comparison .event_counts1
187176 shared = max (n_coincidents / event_counts [unit_i ], n_coincidents / event_counts [unit_j ])
188177 if shared > duplicate_threshold :
189- redundant_unit_pairs .append ([ unit_i , unit_j ] )
178+ redundant_unit_pairs .append (( unit_i , unit_j ) )
190179
191180 return redundant_unit_pairs
0 commit comments