Skip to content

Commit c64b840

Browse files
committed
Implement feedback
1 parent a9ed838 commit c64b840

4 files changed

Lines changed: 72 additions & 54 deletions

File tree

src/spikeinterface/curation/curation_format.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def apply_curation_labels(
128128
sorting.set_property(key, all_values)
129129

130130
for new_unit_id, merge in zip(new_unit_ids, curation_model.merges):
131-
old_group_ids = merge.merge_unit_group
131+
old_group_ids = merge.unit_ids
132132
for label_key, label_def in curation_model.label_definitions.items():
133133
if label_def.exclusive:
134134
group_values = []
@@ -221,7 +221,7 @@ def apply_curation(
221221
if len(curation_model.merges) > 0:
222222
sorting, _, new_unit_ids = apply_merges_to_sorting(
223223
sorting,
224-
merge_unit_groups=[m.merge_unit_group for m in curation_model.merges],
224+
merge_unit_groups=[m.unit_ids for m in curation_model.merges],
225225
censor_ms=censor_ms,
226226
return_extra=True,
227227
new_id_strategy=new_id_strategy,
@@ -237,7 +237,7 @@ def apply_curation(
237237
analyzer = analyzer.remove_units(curation_model.removed)
238238
if len(curation_model.removed) > 0:
239239
analyzer, new_unit_ids = analyzer.merge_units(
240-
merge_unit_groups=[m.merge_unit_group for m in curation_model.merges],
240+
merge_unit_groups=[m.unit_ids for m in curation_model.merges],
241241
censor_ms=censor_ms,
242242
merging_mode=merging_mode,
243243
sparsity_overlap=sparsity_overlap,

src/spikeinterface/curation/curation_model.py

Lines changed: 60 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,31 @@ class ManualLabel(BaseModel):
1616

1717

1818
class Merge(BaseModel):
19-
merge_unit_group: List[Union[int, str]] = Field(..., description="List of groups of units to be merged")
20-
merge_new_unit_id: Optional[Union[int, str]] = Field(default=None, description="New unit IDs for the merge group")
19+
unit_ids: List[Union[int, str]] = Field(..., description="List of unit ids to be merged")
20+
new_unit_id: Optional[Union[int, str]] = Field(default=None, description="New unit IDs for the merge group")
2121

2222

2323
class Split(BaseModel):
2424
unit_id: Union[int, str] = Field(..., description="ID of the unit")
25-
split_mode: Literal["indices", "labels"] = Field(
25+
mode: Literal["indices", "labels"] = Field(
2626
default="indices",
2727
description=(
2828
"Mode of the split. The split can be defined by indices or labels. "
2929
"If indices, the split is defined by a list of lists of indices of spikes within spikes "
30-
"belonging to the unit (`split_indices`). "
31-
"If labels, the split is defined by a list of labels for each spike (`split_labels`). "
30+
"belonging to the unit (`indices`). "
31+
"If labels, the split is defined by a list of labels for each spike (`labels`). "
3232
),
3333
)
34-
split_indices: Optional[Union[List[List[int]]]] = Field(default=None, description="List of indices for the split")
35-
split_labels: Optional[List[int]] = Field(default=None, description="List of labels for the split")
36-
split_new_unit_ids: Optional[List[Union[int, str]]] = Field(
34+
indices: Optional[Union[List[int], List[List[int]]]] = Field(
35+
default=None,
36+
description=(
37+
"List of indices for the split. If a list of indices, the unit is split in 2 (provided indices/others). "
38+
"If a list of lists, the unit is split in multiple groups (one for each list of indices), plus an optional "
39+
"extra group if the spike train has more spikes than the total number of indices in the lists."
40+
),
41+
)
42+
labels: Optional[List[int]] = Field(default=None, description="List of labels for the split")
43+
new_unit_ids: Optional[List[Union[int, str]]] = Field(
3744
default=None, description="List of new unit IDs for each split"
3845
)
3946

@@ -129,25 +136,36 @@ def check_merges(cls, values):
129136
# Validate merges
130137
for merge in merges:
131138
# Check unit ids exist
132-
for unit_id in merge.merge_unit_group:
139+
for unit_id in merge.unit_ids:
133140
if unit_id not in unit_ids:
134141
raise ValueError(f"Merge unit group unit_id {unit_id} is not in the unit list")
135142

136143
# Check minimum group size
137-
if len(merge.merge_unit_group) < 2:
144+
if len(merge.unit_ids) < 2:
138145
raise ValueError("Merge unit groups must have at least 2 elements")
139146

140147
# Check new unit id not already used
141-
if merge.merge_new_unit_id is not None:
142-
if merge.merge_new_unit_id in unit_ids:
143-
raise ValueError(f"New unit ID {merge.merge_new_unit_id} is already in the unit list")
148+
if merge.new_unit_id is not None:
149+
if merge.new_unit_id in unit_ids:
150+
raise ValueError(f"New unit ID {merge.new_unit_id} is already in the unit list")
144151

145152
values["merges"] = merges
146153
return values
147154

148155
@classmethod
149156
def check_splits(cls, values):
150-
157+
"""
158+
Checks and validates the splits in the curation model.
159+
If `splits` is a dictionary with unit_id as key and split data as values,
160+
it converts it to a list of Split objects.
161+
Each Split object is then validated:
162+
- Checks if the unit_id exists in the unit_ids list.
163+
- Validates the mode (indices or labels).
164+
- If mode is indices, checks that indices are defined and not empty, and that there are no duplicate indices.
165+
- If mode is labels, checks that labels are defined and not empty.
166+
- Validates new unit IDs if provided, ensuring they are not already in the unit_ids list and match the
167+
number of splits.
168+
"""
151169
unit_ids = list(values["unit_ids"])
152170
splits = values.get("splits")
153171
if splits is None:
@@ -162,12 +180,12 @@ def check_splits(cls, values):
162180
split_list.append(
163181
{
164182
"unit_id": unit_id,
165-
"split_mode": "indices",
166-
"split_indices": [list(indices) for indices in split_data],
183+
"mode": "indices",
184+
"indices": [list(indices) for indices in split_data],
167185
}
168186
)
169187
else:
170-
split_list.append({"unit_id": unit_id, "split_mode": "labels", "split_labels": list(split_data)})
188+
split_list.append({"unit_id": unit_id, "mode": "labels", "labels": list(split_data)})
171189
splits = split_list
172190

173191
# Make a copy of the list
@@ -177,12 +195,12 @@ def check_splits(cls, values):
177195
for i, split in enumerate(splits):
178196
if isinstance(split, dict):
179197
split = dict(split)
180-
if "split_indices" in split:
181-
split["split_indices"] = [list(indices) for indices in split["split_indices"]]
182-
if "split_labels" in split:
183-
split["split_labels"] = list(split["split_labels"])
184-
if "split_new_unit_ids" in split:
185-
split["split_new_unit_ids"] = list(split["split_new_unit_ids"])
198+
if "indices" in split:
199+
split["indices"] = [list(indices) for indices in split["indices"]]
200+
if "labels" in split:
201+
split["labels"] = list(split["labels"])
202+
if "new_unit_ids" in split:
203+
split["new_unit_ids"] = list(split["new_unit_ids"])
186204
splits[i] = Split(**split)
187205

188206
# Validate splits
@@ -192,36 +210,36 @@ def check_splits(cls, values):
192210
raise ValueError(f"Split unit_id {split.unit_id} is not in the unit list")
193211

194212
# Validate based on mode
195-
if split.split_mode == "indices":
196-
if split.split_indices is None:
197-
raise ValueError(f"Split unit {split.unit_id} has no split_indices defined")
198-
if len(split.split_indices) < 1:
199-
raise ValueError(f"Split unit {split.unit_id} has empty split_indices")
213+
if split.mode == "indices":
214+
if split.indices is None:
215+
raise ValueError(f"Split unit {split.unit_id} has no indices defined")
216+
if len(split.indices) < 1:
217+
raise ValueError(f"Split unit {split.unit_id} has empty indices")
200218
# Check no duplicate indices
201-
all_indices = list(chain.from_iterable(split.split_indices))
219+
all_indices = list(chain.from_iterable(split.indices))
202220
if len(all_indices) != len(set(all_indices)):
203221
raise ValueError(f"Split unit {split.unit_id} has duplicate indices")
204222

205-
elif split.split_mode == "labels":
206-
if split.split_labels is None:
207-
raise ValueError(f"Split unit {split.unit_id} has no split_labels defined")
208-
if len(split.split_labels) == 0:
209-
raise ValueError(f"Split unit {split.unit_id} has empty split_labels")
223+
elif split.mode == "labels":
224+
if split.labels is None:
225+
raise ValueError(f"Split unit {split.unit_id} has no labels defined")
226+
if len(split.labels) == 0:
227+
raise ValueError(f"Split unit {split.unit_id} has empty labels")
210228

211229
# Validate new unit IDs
212-
if split.split_new_unit_ids is not None:
213-
if split.split_mode == "indices":
214-
if len(split.split_new_unit_ids) != len(split.split_indices):
230+
if split.new_unit_ids is not None:
231+
if split.mode == "indices":
232+
if len(split.new_unit_ids) != len(split.indices):
215233
raise ValueError(
216234
f"Number of new unit IDs does not match number of splits for unit {split.unit_id}"
217235
)
218-
elif split.split_mode == "labels":
219-
if len(split.split_new_unit_ids) != len(set(split.split_labels)):
236+
elif split.mode == "labels":
237+
if len(split.new_unit_ids) != len(set(split.labels)):
220238
raise ValueError(
221239
f"Number of new unit IDs does not match number of unique labels for unit {split.unit_id}"
222240
)
223241

224-
for new_id in split.split_new_unit_ids:
242+
for new_id in split.new_unit_ids:
225243
if new_id in unit_ids:
226244
raise ValueError(f"New unit ID {new_id} is already in the unit list")
227245

@@ -312,7 +330,7 @@ def validate_curation_dict(cls, values):
312330

313331
labeled_unit_set = set([lbl.unit_id for lbl in values.manual_labels]) if values.manual_labels else set()
314332
merged_units_set = (
315-
set(chain.from_iterable(merge.merge_unit_group for merge in values.merges)) if values.merges else set()
333+
set(chain.from_iterable(merge.unit_ids for merge in values.merges)) if values.merges else set()
316334
)
317335
split_units_set = set(split.unit_id for split in values.splits) if values.splits else set()
318336
removed_set = set(values.removed) if values.removed else set()
@@ -329,7 +347,7 @@ def validate_curation_dict(cls, values):
329347
raise ValueError("Curation format: some removed units are not in the unit list")
330348

331349
# Check for units being merged multiple times
332-
all_merging_groups = [set(merge.merge_unit_group) for merge in values.merges] if values.merges else []
350+
all_merging_groups = [set(merge.unit_ids) for merge in values.merges] if values.merges else []
333351
for gp_1, gp_2 in combinations(all_merging_groups, 2):
334352
if len(gp_1.intersection(gp_2)) != 0:
335353
raise ValueError("Curation format: some units belong to multiple merge groups")

src/spikeinterface/curation/sortingview_curation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def apply_sortingview_curation(
112112
clean_merges = []
113113
for merge in curation_model.merges:
114114
clean_merge = []
115-
for unit_id in merge.merge_unit_group:
115+
for unit_id in merge.unit_ids:
116116
if unit_id not in curation_model.removed:
117117
clean_merge.append(unit_id)
118118
if len(clean_merge) > 1:

src/spikeinterface/curation/tests/test_curation_model.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,15 @@ def test_merge_units():
9898

9999
model = CurationModel(**valid_merge)
100100
assert len(model.merges) == 2
101-
assert model.merges[0].merge_new_unit_id == 5
102-
assert model.merges[1].merge_new_unit_id == 6
101+
assert model.merges[0].new_unit_id == 5
102+
assert model.merges[1].new_unit_id == 6
103103

104104
# Test dictionary format
105105
valid_merge_dict = {"format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": {5: [1, 2], 6: [3, 4]}}
106106

107107
model = CurationModel(**valid_merge_dict)
108108
assert len(model.merges) == 2
109-
merge_new_ids = {merge.merge_new_unit_id for merge in model.merges}
109+
merge_new_ids = {merge.new_unit_id for merge in model.merges}
110110
assert merge_new_ids == {5, 6}
111111

112112
# Test list format
@@ -158,8 +158,8 @@ def test_split_units():
158158

159159
model = CurationModel(**valid_split_indices)
160160
assert len(model.splits) == 1
161-
assert model.splits[0].split_mode == "indices"
162-
assert len(model.splits[0].split_indices) == 2
161+
assert model.splits[0].mode == "indices"
162+
assert len(model.splits[0].indices) == 2
163163

164164
# Test labels mode with list format
165165
valid_split_labels = {
@@ -172,8 +172,8 @@ def test_split_units():
172172

173173
model = CurationModel(**valid_split_labels)
174174
assert len(model.splits) == 1
175-
assert model.splits[0].split_mode == "labels"
176-
assert len(set(model.splits[0].split_labels)) == 3
175+
assert model.splits[0].mode == "labels"
176+
assert len(set(model.splits[0].labels)) == 3
177177

178178
# Test dictionary format with indices
179179
valid_split_dict = {
@@ -187,7 +187,7 @@ def test_split_units():
187187

188188
model = CurationModel(**valid_split_dict)
189189
assert len(model.splits) == 2
190-
assert all(split.split_mode == "indices" for split in model.splits)
190+
assert all(split.mode == "indices" for split in model.splits)
191191

192192
# Test invalid unit ID
193193
invalid_unit_id = {

0 commit comments

Comments
 (0)