Skip to content

Commit 4d892e9

Browse files
authored
Merge pull request #4138 from alejoe91/fix-curation-api-doc
Improve docs for curation module and model
2 parents 3e4fcb4 + 5e4ce6e commit 4d892e9

12 files changed

Lines changed: 99 additions & 101 deletions

File tree

doc/api.rst

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,21 +356,36 @@ spikeinterface.curation
356356
.. automodule:: spikeinterface.curation
357357

358358
.. autofunction:: apply_curation
359-
.. autofunction:: get_potential_auto_merge
359+
.. autofunction:: compute_merge_unit_groups
360360
.. autofunction:: find_redundant_units
361361
.. autofunction:: remove_redundant_units
362362
.. autofunction:: remove_duplicated_spikes
363363
.. autofunction:: remove_excess_spikes
364-
.. autofunction:: load_model
365364
.. autofunction:: auto_label_units
365+
.. autofunction:: load_model
366366
.. autofunction:: train_model
367367

368+
Curation Model
369+
~~~~~~~~~~~~~~
370+
371+
This section describes the ``pydantic`` curation model classes used to represent and manage curation actions
372+
such as merging and splitting units, as well as defining labels for units.
373+
374+
.. automodule:: spikeinterface.curation.curation_model
375+
376+
.. autopydantic_model:: CurationModel
377+
.. autopydantic_model:: Merge
378+
.. autopydantic_model:: Split
379+
.. autopydantic_model:: ManualLabel
380+
.. autopydantic_model:: LabelDefinition
381+
368382
Deprecated
369383
~~~~~~~~~~
370384
.. automodule:: spikeinterface.curation
371385
:noindex:
372386

373387
.. autofunction:: apply_sortingview_curation
388+
.. autofunction:: get_potential_auto_merge
374389
.. autoclass:: CurationSorting
375390
.. autoclass:: MergeUnitsSorting
376391
.. autoclass:: SplitUnitSorting

doc/conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
'sphinx.ext.autosummary',
6464
'sphinx_gallery.gen_gallery',
6565
'numpydoc',
66+
'sphinxcontrib.autodoc_pydantic',
6667
'sphinx.ext.autosectionlabel',
6768
'sphinx_design',
6869
'sphinxcontrib.jquery',
@@ -76,6 +77,8 @@
7677

7778
numpydoc_show_class_members = False
7879

80+
autodoc_pydantic_model_show_json = True
81+
autodoc_pydantic_model_show_config_summary = False
7982

8083
# Add any paths that contain templates here, relative to this directory.
8184
templates_path = ['_templates']

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,13 @@ test = [
194194

195195
docs = [
196196
"Sphinx",
197+
"ipython",
197198
"sphinx_rtd_theme>=1.2",
198199
"sphinx-gallery",
199200
"sphinx-design",
200201
"numpydoc",
201-
"ipython",
202202
"sphinxcontrib-jquery",
203+
"autodoc_pydantic",
203204

204205
# for notebooks in the gallery
205206
"MEArec", # Use as an example

src/spikeinterface/core/sorting_tools.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -365,16 +365,13 @@ def set_properties_after_merging(
365365

366366
for key in prop_keys:
367367
parent_values = sorting_pre_merge.get_property(key)
368-
if parent_values.dtype.kind not in default_missing_values:
369-
# if the property is boolean or integer there is no missing values so we skip
370-
# for instance recursive "is_merged" will not be propagated
371-
continue
372368

373369
# propagate keep values
374370
shape = (len(sorting_post_merge.unit_ids),) + parent_values.shape[1:]
375371
new_values = np.empty(shape=shape, dtype=parent_values.dtype)
376372
new_values[keep_post_inds] = parent_values[keep_pre_inds]
377373

374+
skip_property = False
378375
for new_id, merge_group in zip(new_unit_ids, merge_unit_groups):
379376
merged_indices = sorting_pre_merge.ids_to_indices(merge_group)
380377
merge_values = parent_values[merged_indices]
@@ -384,9 +381,15 @@ def set_properties_after_merging(
384381
# and new values only if they are all similar
385382
new_values[new_index] = merge_values[0]
386383
else:
387-
388-
new_values[new_index] = default_missing_values[parent_values.dtype.kind]
389-
sorting_post_merge.set_property(key, new_values)
384+
if parent_values.dtype.kind not in default_missing_values:
385+
# if the property doesn't have a default missing value and it is not the same
386+
# for all merged units, we skip it
387+
skip_property = True
388+
break
389+
else:
390+
new_values[new_index] = default_missing_values[parent_values.dtype.kind]
391+
if not skip_property:
392+
sorting_post_merge.set_property(key, new_values)
390393

391394
# set is_merged property
392395
is_merged = np.ones(len(sorting_post_merge.unit_ids), dtype=bool)

src/spikeinterface/core/sortinganalyzer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,10 @@ def _save_or_select_or_merge_or_split(
10721072
if sorting_provenance is None:
10731073
# if the original sorting object is not available anymore (kilosort folder deleted, ...), take the copy
10741074
sorting_provenance = self.sorting
1075+
# add in-memory properties added to the analyzer
1076+
for key in self.sorting.get_property_keys():
1077+
if key not in sorting_provenance.get_property_keys():
1078+
sorting_provenance.set_property(key, self.sorting.get_property(key))
10751079

10761080
if merge_unit_groups is None and split_units is None:
10771081
# when only some unit_ids then the sorting must be sliced

src/spikeinterface/curation/auto_merge.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def compute_merge_unit_groups(
144144
* | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN.
145145
| It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations",
146146
| "knn", "quality_score"
147+
147148
If `preset` is None, you can specify the steps manually with the `steps` parameter.
148149
resolve_graph : bool, default: True
149150
If True, the function resolves the potential unit pairs to be merged into multiple-unit merges.

src/spikeinterface/curation/curation_format.py

Lines changed: 19 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -125,60 +125,14 @@ def apply_curation_labels(
125125
# Please note that manual_labels is done on the unit_ids before the merge!!!
126126
manual_labels = curation_label_to_vectors(curation_model)
127127

128-
# apply on non merged / split
129-
merge_new_unit_ids = [m.new_unit_id for m in curation_model.merges]
130-
split_new_unit_ids = [m.new_unit_ids for m in curation_model.splits]
131-
split_new_unit_ids = list(chain(*split_new_unit_ids))
132-
133-
merged_split_units = merge_new_unit_ids + split_new_unit_ids
134128
for key, values in manual_labels.items():
135129
all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype)
136130
for unit_ind, unit_id in enumerate(sorting.unit_ids):
137-
if unit_id not in merged_split_units:
138-
ind = list(curation_model.unit_ids).index(unit_id)
139-
all_values[unit_ind] = values[ind]
131+
# if unit_id not in merged_split_units:
132+
ind = list(curation_model.unit_ids).index(unit_id)
133+
all_values[unit_ind] = values[ind]
140134
sorting.set_property(key, all_values)
141135

142-
for new_unit_id, merge in zip(merge_new_unit_ids, curation_model.merges):
143-
old_group_ids = merge.unit_ids
144-
for label_key, label_def in curation_model.label_definitions.items():
145-
if label_def.exclusive:
146-
group_values = []
147-
for unit_id in old_group_ids:
148-
ind = list(curation_model.unit_ids).index(unit_id)
149-
value = manual_labels[label_key][ind]
150-
if value != "":
151-
group_values.append(value)
152-
if len(set(group_values)) == 1:
153-
# all group has the same label or empty
154-
sorting.set_property(key, values=group_values[:1], ids=[new_unit_id])
155-
else:
156-
for key in label_def.label_options:
157-
group_values = []
158-
for unit_id in old_group_ids:
159-
ind = list(curation_model.unit_ids).index(unit_id)
160-
value = manual_labels[key][ind]
161-
group_values.append(value)
162-
new_value = np.any(group_values)
163-
sorting.set_property(key, values=[new_value], ids=[new_unit_id])
164-
165-
# splits
166-
for split in curation_model.splits:
167-
# propagate property of splut unit to new units
168-
old_unit = split.unit_id
169-
new_unit_ids = split.new_unit_ids
170-
for label_key, label_def in curation_model.label_definitions.items():
171-
if label_def.exclusive:
172-
ind = list(curation_model.unit_ids).index(old_unit)
173-
value = manual_labels[label_key][ind]
174-
if value != "":
175-
sorting.set_property(label_key, values=[value] * len(new_unit_ids), ids=new_unit_ids)
176-
else:
177-
for key in label_def.label_options:
178-
ind = list(curation_model.unit_ids).index(old_unit)
179-
value = manual_labels[key][ind]
180-
sorting.set_property(key, values=[value] * len(new_unit_ids), ids=new_unit_ids)
181-
182136

183137
def apply_curation(
184138
sorting_or_analyzer: BaseSorting | SortingAnalyzer,
@@ -194,10 +148,11 @@ def apply_curation(
194148
Apply curation dict to a Sorting or a SortingAnalyzer.
195149
196150
Steps are done in this order:
197-
1. Apply removal using curation_dict["removed"]
198-
2. Apply merges using curation_dict["merges"]
199-
3. Apply splits using curation_dict["splits"]
200-
4. Set labels using curation_dict["manual_labels"]
151+
152+
1. Apply labels using curation_dict["manual_labels"]
153+
2. Apply removal using curation_dict["removed"]
154+
3. Apply merges using curation_dict["merges"]
155+
4. Apply splits using curation_dict["splits"]
201156
202157
A new Sorting or SortingAnalyzer (in memory) is returned.
203158
The user (an adult) has the responsibility to save it somewhere (or not).
@@ -243,33 +198,36 @@ def apply_curation(
243198
if isinstance(curation_dict_or_model, dict):
244199
curation_model = CurationModel(**curation_dict_or_model)
245200
else:
246-
curation_model = curation_dict_or_model
201+
curation_model = curation_dict_or_model.model_copy(deep=True)
247202

248203
if not np.array_equal(np.asarray(curation_model.unit_ids), sorting_or_analyzer.unit_ids):
249204
raise ValueError("unit_ids from the curation_dict do not match the one from Sorting or SortingAnalyzer")
250205

251-
# 1. Remove units
206+
# 1. Apply labels
207+
apply_curation_labels(sorting_or_analyzer, curation_model)
208+
209+
# 2. Remove units
252210
if len(curation_model.removed) > 0:
253211
curated_sorting_or_analyzer = sorting_or_analyzer.remove_units(curation_model.removed)
254212
else:
255213
curated_sorting_or_analyzer = sorting_or_analyzer
256214

257-
# 2. Merge units
215+
# 3. Merge units
258216
if len(curation_model.merges) > 0:
259217
merge_unit_groups = [m.unit_ids for m in curation_model.merges]
260218
merge_new_unit_ids = [m.new_unit_id for m in curation_model.merges if m.new_unit_id is not None]
261219
if len(merge_new_unit_ids) == 0:
262220
merge_new_unit_ids = None
263221
if isinstance(sorting_or_analyzer, BaseSorting):
264-
curated_sorting_or_analyzer, _, new_unit_ids = apply_merges_to_sorting(
222+
curated_sorting_or_analyzer, _, _ = apply_merges_to_sorting(
265223
curated_sorting_or_analyzer,
266224
merge_unit_groups=merge_unit_groups,
267225
censor_ms=censor_ms,
268226
new_id_strategy=new_id_strategy,
269227
return_extra=True,
270228
)
271229
else:
272-
curated_sorting_or_analyzer, new_unit_ids = curated_sorting_or_analyzer.merge_units(
230+
curated_sorting_or_analyzer, _ = curated_sorting_or_analyzer.merge_units(
273231
merge_unit_groups=merge_unit_groups,
274232
censor_ms=censor_ms,
275233
merging_mode=merging_mode,
@@ -280,10 +238,8 @@ def apply_curation(
280238
verbose=verbose,
281239
**job_kwargs,
282240
)
283-
for i, merge_unit_id in enumerate(new_unit_ids):
284-
curation_model.merges[i].new_unit_id = merge_unit_id
285241

286-
# 3. Split units
242+
# 4. Split units
287243
if len(curation_model.splits) > 0:
288244
split_units = {}
289245
for split in curation_model.splits:
@@ -297,26 +253,21 @@ def apply_curation(
297253
if len(split_new_unit_ids) == 0:
298254
split_new_unit_ids = None
299255
if isinstance(sorting_or_analyzer, BaseSorting):
300-
curated_sorting_or_analyzer, new_unit_ids = apply_splits_to_sorting(
256+
curated_sorting_or_analyzer, _ = apply_splits_to_sorting(
301257
curated_sorting_or_analyzer,
302258
split_units,
303259
new_unit_ids=split_new_unit_ids,
304260
new_id_strategy=new_id_strategy,
305261
return_extra=True,
306262
)
307263
else:
308-
curated_sorting_or_analyzer, new_unit_ids = curated_sorting_or_analyzer.split_units(
264+
curated_sorting_or_analyzer, _ = curated_sorting_or_analyzer.split_units(
309265
split_units,
310266
new_id_strategy=new_id_strategy,
311267
return_new_unit_ids=True,
312268
new_unit_ids=split_new_unit_ids,
313269
format="memory",
314270
verbose=verbose,
315271
)
316-
for i, split_unit_ids in enumerate(new_unit_ids):
317-
curation_model.splits[i].new_unit_ids = split_unit_ids
318-
319-
# 4. Apply labels
320-
apply_curation_labels(curated_sorting_or_analyzer, curation_model)
321272

322273
return curated_sorting_or_analyzer

src/spikeinterface/curation/curation_model.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,13 @@ def add_label_definition_name(cls, label_definitions):
104104

105105
@classmethod
106106
def check_manual_labels(cls, values):
107+
"""
108+
Checks and validates the manual labels in the curation model.
109+
110+
* Checks if the unit_ids in each manual label exist in the unit_ids list.
111+
* Validates that each label in the manual labels exists in the label_definitions.
112+
113+
"""
107114
unit_ids = list(values["unit_ids"])
108115
manual_labels = values.get("manual_labels")
109116
if manual_labels is None:
@@ -135,6 +142,15 @@ def check_manual_labels(cls, values):
135142

136143
@classmethod
137144
def check_merges(cls, values):
145+
"""
146+
Checks and validates the merges in the curation model.
147+
148+
* Checks if the unit_ids in each merge group exist in the unit_ids list.
149+
* Validates that each merge group has at least two unit IDs.
150+
* Ensures that any new_unit_id provided does not already exist in the unit_ids list.
151+
* Converts merges from dict format to list of Merge objects if necessary.
152+
153+
"""
138154
unit_ids = list(values["unit_ids"])
139155
merges = values.get("merges")
140156
if merges is None:
@@ -184,15 +200,14 @@ def check_merges(cls, values):
184200
def check_splits(cls, values):
185201
"""
186202
Checks and validates the splits in the curation model.
187-
If `splits` is a dictionary with unit_id as key and split indices as values,
188-
it converts it to a list of Split objects.
189-
Each Split object is then validated:
190-
- Checks if the unit_id exists in the unit_ids list.
191-
- Validates the mode (indices or labels).
192-
- If mode is indices, checks that indices are defined and not empty, and that there are no duplicate indices.
193-
- If mode is labels, checks that labels are defined and not empty.
194-
- Validates new unit IDs if provided, ensuring they are not already in the unit_ids list and match the
195-
number of splits.
203+
204+
* Checks if the unit_id exists in the unit_ids list.
205+
* Validates the mode (indices or labels).
206+
* If mode is indices, checks that indices are defined and not empty, and that there are no duplicate indices.
207+
* If mode is labels, checks that labels are defined and not empty.
208+
* | Validates new unit IDs if provided, ensuring they are not already in the unit_ids list and match the
209+
| number of splits.
210+
196211
"""
197212
unit_ids = list(values["unit_ids"])
198213
splits = values.get("splits")
@@ -279,6 +294,11 @@ def check_splits(cls, values):
279294

280295
@classmethod
281296
def check_removed(cls, values):
297+
"""
298+
Checks and validates the removed units in the curation model.
299+
If `removed` is None, it initializes it as an empty list.
300+
It then checks that each unit ID in `removed` exists in the `unit_ids` list.
301+
"""
282302
unit_ids = list(values["unit_ids"])
283303
removed = values.get("removed")
284304
if removed is None:
@@ -293,6 +313,11 @@ def check_removed(cls, values):
293313

294314
@classmethod
295315
def convert_old_format(cls, values):
316+
"""
317+
Converts old curation formats (v0 and v1) to the current format (v2).
318+
v0 (sortingview) format is converted to v2 by extracting labels, merges, and unit IDs.
319+
v1 format is updated to v2 by renaming fields and ensuring the structure matches the v2 format.
320+
"""
296321
format_version = values.get("format_version", "0")
297322
if format_version == "0":
298323
print("Conversion from format version v0 (sortingview) to v2")

src/spikeinterface/curation/tests/sv-sorting-curation-int.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"reject"
1111
],
1212
"4": [
13-
"noise"
13+
"reject"
1414
],
1515
"5": [
1616
"accept"

src/spikeinterface/curation/tests/sv-sorting-curation-str.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"reject"
1111
],
1212
"d": [
13-
"noise"
13+
"reject"
1414
],
1515
"e": [
1616
"accept"

0 commit comments

Comments
 (0)