Skip to content

Commit a44b28a

Browse files
ecobost, pre-commit-ci[bot], and alejoe91
authored
Send peak_sign to get_template_extremum_channel_peak_shift in remove_redundant_units (#4408)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>
1 parent 5fb3668 commit a44b28a

2 files changed

Lines changed: 31 additions & 47 deletions

File tree

src/spikeinterface/core/sortinganalyzer.py

Lines changed: 11 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -170,27 +170,22 @@ def create_sorting_analyzer(
170170
**sparsity_kwargs,
171171
)
172172

173-
if format != "memory":
174-
if format == "zarr":
175-
if not is_path_remote(folder):
176-
folder = clean_zarr_folder_name(folder)
177-
if not is_path_remote(folder):
178-
if Path(folder).is_dir():
179-
if not overwrite:
180-
raise ValueError(f"Folder already exists {folder}! Use overwrite=True to overwrite it.")
181-
else:
182-
shutil.rmtree(folder)
173+
if format != "memory" and not is_path_remote(folder):
174+
folder = clean_zarr_folder_name(folder) if format == "zarr" else folder
175+
if Path(folder).is_dir():
176+
if overwrite:
177+
shutil.rmtree(folder)
178+
else:
179+
raise ValueError(f"Folder {folder} already exists! Use overwrite=True to overwrite it.")
183180

184181
# handle sparsity
185182
if sparsity is not None:
186183
# some checks
187184
assert isinstance(sparsity, ChannelSparsity), "'sparsity' must be a ChannelSparsity object"
188-
assert np.array_equal(
189-
sorting.unit_ids, sparsity.unit_ids
190-
), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond"
191-
assert np.array_equal(
192-
recording.channel_ids, sparsity.channel_ids
193-
), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond"
185+
error_msg = "If external sparsity is given, unit_ids must match sorting"
186+
assert np.array_equal(sorting.unit_ids, sparsity.unit_ids), error_msg
187+
error_msg = "If external sparsity is given, channel_ids must match recording"
188+
assert np.array_equal(recording.channel_ids, sparsity.channel_ids), error_msg
194189
elif sparse:
195190
sparsity = estimate_sparsity(sorting, recording, **sparsity_kwargs)
196191
else:

src/spikeinterface/curation/remove_redundant.py

Lines changed: 20 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -10,16 +10,16 @@
1010

1111

1212
def remove_redundant_units(
13-
sorting_or_sorting_analyzer,
14-
align=True,
15-
unit_peak_shifts=None,
16-
delta_time=0.4,
17-
agreement_threshold=0.2,
18-
duplicate_threshold=0.8,
19-
remove_strategy="minimum_shift",
20-
peak_sign="neg",
21-
extra_outputs=False,
22-
) -> BaseSorting:
13+
sorting_or_sorting_analyzer: BaseSorting | SortingAnalyzer,
14+
align: bool = True,
15+
unit_peak_shifts: dict[int, float] | None = None,
16+
delta_time: float = 0.4,
17+
agreement_threshold: float = 0.2,
18+
duplicate_threshold: float = 0.8,
19+
remove_strategy: str = "minimum_shift",
20+
peak_sign: str = "neg",
21+
extra_outputs: bool = False,
22+
) -> BaseSorting | tuple[BaseSorting, list[tuple[int, int]]]:
2323
"""
2424
Removes redundant or duplicate units by comparing the sorting output with itself.
2525
@@ -72,15 +72,14 @@ def remove_redundant_units(
7272
sorting = sorting_or_sorting_analyzer.sorting
7373
sorting_analyzer = sorting_or_sorting_analyzer
7474
else:
75-
assert not align, "The 'align' option is only available when a SortingAnalyzer is used as input"
7675
# other remove strategies rely on sorting analyzer looking at templates
7776
assert remove_strategy == "max_spikes", "For a Sorting input the remove_strategy must be 'max_spikes'"
7877
sorting = sorting_or_sorting_analyzer
7978
sorting_analyzer = None
8079

8180
if align and unit_peak_shifts is None:
8281
assert sorting_analyzer is not None, "For align=True must give a SortingAnalyzer or explicit unit_peak_shifts"
83-
unit_peak_shifts = get_template_extremum_channel_peak_shift(sorting_analyzer)
82+
unit_peak_shifts = get_template_extremum_channel_peak_shift(sorting_analyzer, peak_sign=peak_sign)
8483

8584
if align:
8685
sorting_aligned = align_sorting(sorting, unit_peak_shifts)
@@ -108,25 +107,15 @@ def remove_redundant_units(
108107
remove_unit_ids.append(u1)
109108
elif np.abs(unit_peak_shifts[u1]) < np.abs(unit_peak_shifts[u2]):
110109
remove_unit_ids.append(u2)
111-
else:
112-
# equal shift use peak values
113-
if np.abs(peak_values[u1]) < np.abs(peak_values[u2]):
114-
remove_unit_ids.append(u1)
115-
else:
116-
remove_unit_ids.append(u2)
110+
else: # equal shift use peak values
111+
remove_unit_ids.append(u1 if peak_values[u1] < peak_values[u2] else u2)
117112
elif remove_strategy == "highest_amplitude":
118113
for u1, u2 in redundant_unit_pairs:
119-
if np.abs(peak_values[u1]) < np.abs(peak_values[u2]):
120-
remove_unit_ids.append(u1)
121-
else:
122-
remove_unit_ids.append(u2)
114+
remove_unit_ids.append(u1 if peak_values[u1] < peak_values[u2] else u2)
123115
elif remove_strategy == "max_spikes":
124116
num_spikes = sorting.count_num_spikes_per_unit(outputs="dict")
125117
for u1, u2 in redundant_unit_pairs:
126-
if num_spikes[u1] < num_spikes[u2]:
127-
remove_unit_ids.append(u1)
128-
else:
129-
remove_unit_ids.append(u2)
118+
remove_unit_ids.append(u1 if num_spikes[u1] < num_spikes[u2] else u2)
130119
elif remove_strategy == "with_metrics":
131120
# TODO
132121
# @aurelien @alessio
@@ -144,7 +133,9 @@ def remove_redundant_units(
144133
return sorting_clean
145134

146135

147-
def find_redundant_units(sorting, delta_time: float = 0.4, agreement_threshold=0.2, duplicate_threshold=0.8):
136+
def find_redundant_units(
137+
sorting: BaseSorting, delta_time: float = 0.4, agreement_threshold: float = 0.2, duplicate_threshold: float = 0.8
138+
) -> list[tuple[int, int]]:
148139
"""
149140
Finds redundant or duplicate units by comparing the sorting output with itself.
150141
@@ -162,9 +153,7 @@ def find_redundant_units(sorting, delta_time: float = 0.4, agreement_threshold=0
162153
163154
Returns
164155
-------
165-
list
166-
The list of duplicate units
167-
list of 2-element lists
156+
list of 2-element tuples
168157
The list of duplicate pairs
169158
"""
170159
from spikeinterface.comparison import compare_two_sorters
@@ -186,6 +175,6 @@ def find_redundant_units(sorting, delta_time: float = 0.4, agreement_threshold=0
186175
event_counts = comparison.event_counts1
187176
shared = max(n_coincidents / event_counts[unit_i], n_coincidents / event_counts[unit_j])
188177
if shared > duplicate_threshold:
189-
redundant_unit_pairs.append([unit_i, unit_j])
178+
redundant_unit_pairs.append((unit_i, unit_j))
190179

191180
return redundant_unit_pairs

0 commit comments

Comments
 (0)