Skip to content

Commit 0baa85a

Browse files
committed
code formatting
1 parent 97f3ab4 commit 0baa85a

2 files changed

Lines changed: 17 additions & 35 deletions

File tree

src/spikeinterface/core/sortinganalyzer.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -170,27 +170,22 @@ def create_sorting_analyzer(
170170
**sparsity_kwargs,
171171
)
172172

173-
if format != "memory":
174-
if format == "zarr":
175-
if not is_path_remote(folder):
176-
folder = clean_zarr_folder_name(folder)
177-
if not is_path_remote(folder):
178-
if Path(folder).is_dir():
179-
if not overwrite:
180-
raise ValueError(f"Folder already exists {folder}! Use overwrite=True to overwrite it.")
181-
else:
182-
shutil.rmtree(folder)
173+
if format != "memory" and not is_path_remote(folder):
174+
folder = clean_zarr_folder_name(folder) if format == "zarr" else folder
175+
if Path(folder).is_dir():
176+
if overwrite:
177+
shutil.rmtree(folder)
178+
else:
179+
raise ValueError(f"Folder {folder} already exists! Use overwrite=True to overwrite it.")
183180

184181
# handle sparsity
185182
if sparsity is not None:
186183
# some checks
187184
assert isinstance(sparsity, ChannelSparsity), "'sparsity' must be a ChannelSparsity object"
188-
assert np.array_equal(
189-
sorting.unit_ids, sparsity.unit_ids
190-
), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond"
191-
assert np.array_equal(
192-
recording.channel_ids, sparsity.channel_ids
193-
), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond"
185+
error_msg = "If external sparsity is given, unit_ids must match sorting"
186+
assert np.array_equal(sorting.unit_ids, sparsity.unit_ids), error_msg
187+
error_msg = "If external sparsity is given, channel_ids must match recording"
188+
assert np.array_equal(recording.channel_ids, sparsity.channel_ids), error_msg
194189
elif sparse:
195190
sparsity = estimate_sparsity(sorting, recording, **sparsity_kwargs)
196191
else:

src/spikeinterface/curation/remove_redundant.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ def remove_redundant_units(
7272
sorting = sorting_or_sorting_analyzer.sorting
7373
sorting_analyzer = sorting_or_sorting_analyzer
7474
else:
75-
assert not align, "The 'align' option is only available when a SortingAnalyzer is used as input"
7675
# other remove strategies rely on sorting analyzer looking at templates
7776
assert remove_strategy == "max_spikes", "For a Sorting input the remove_strategy must be 'max_spikes'"
7877
sorting = sorting_or_sorting_analyzer
@@ -108,25 +107,15 @@ def remove_redundant_units(
108107
remove_unit_ids.append(u1)
109108
elif np.abs(unit_peak_shifts[u1]) < np.abs(unit_peak_shifts[u2]):
110109
remove_unit_ids.append(u2)
111-
else:
112-
# equal shift use peak values
113-
if np.abs(peak_values[u1]) < np.abs(peak_values[u2]):
114-
remove_unit_ids.append(u1)
115-
else:
116-
remove_unit_ids.append(u2)
110+
else: # equal shift use peak values
111+
remove_unit_ids.append(u1 if peak_values[u1] < peak_values[u2] else u2)
117112
elif remove_strategy == "highest_amplitude":
118113
for u1, u2 in redundant_unit_pairs:
119-
if np.abs(peak_values[u1]) < np.abs(peak_values[u2]):
120-
remove_unit_ids.append(u1)
121-
else:
122-
remove_unit_ids.append(u2)
114+
remove_unit_ids.append(u1 if peak_values[u1] < peak_values[u2] else u2)
123115
elif remove_strategy == "max_spikes":
124116
num_spikes = sorting.count_num_spikes_per_unit(outputs="dict")
125117
for u1, u2 in redundant_unit_pairs:
126-
if num_spikes[u1] < num_spikes[u2]:
127-
remove_unit_ids.append(u1)
128-
else:
129-
remove_unit_ids.append(u2)
118+
remove_unit_ids.append(u1 if num_spikes[u1] < num_spikes[u2] else u2)
130119
elif remove_strategy == "with_metrics":
131120
# TODO
132121
# @aurelien @alessio
@@ -162,9 +151,7 @@ def find_redundant_units(sorting, delta_time: float = 0.4, agreement_threshold=0
162151
163152
Returns
164153
-------
165-
list
166-
The list of duplicate units
167-
list of 2-element lists
154+
list of 2-element tuples
168155
The list of duplicate pairs
169156
"""
170157
from spikeinterface.comparison import compare_two_sorters
@@ -186,6 +173,6 @@ def find_redundant_units(sorting, delta_time: float = 0.4, agreement_threshold=0
186173
event_counts = comparison.event_counts1
187174
shared = max(n_coincidents / event_counts[unit_i], n_coincidents / event_counts[unit_j])
188175
if shared > duplicate_threshold:
189-
redundant_unit_pairs.append([unit_i, unit_j])
176+
redundant_unit_pairs.append((unit_i, unit_j))
190177

191178
return redundant_unit_pairs

0 commit comments

Comments
 (0)