Skip to content

Commit 317f87c

Browse files
committed
merge_new_unit_ids -> merge_new_unit_id
1 parent 4f14e90 commit 317f87c

2 files changed

Lines changed: 19 additions & 21 deletions

File tree

src/spikeinterface/curation/curation_model.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class ManualLabel(BaseModel):
1717

1818
class Merge(BaseModel):
1919
merge_unit_group: List[Union[int, str]] = Field(..., description="List of groups of units to be merged")
20-
merge_new_unit_ids: Optional[Union[int, str]] = Field(default=None, description="New unit IDs for the merge group")
20+
merge_new_unit_id: Optional[Union[int, str]] = Field(default=None, description="New unit IDs for the merge group")
2121

2222

2323
class Split(BaseModel):
@@ -67,7 +67,7 @@ def add_label_definition_name(cls, label_definitions):
6767

6868
@classmethod
6969
def check_manual_labels(cls, values):
70-
values = dict(values)
70+
7171
unit_ids = list(values["unit_ids"])
7272
manual_labels = values.get("manual_labels")
7373
if manual_labels is None:
@@ -99,7 +99,7 @@ def check_manual_labels(cls, values):
9999

100100
@classmethod
101101
def check_merges(cls, values):
102-
values = dict(values)
102+
103103
unit_ids = list(values["unit_ids"])
104104
merges = values.get("merges")
105105
if merges is None:
@@ -110,7 +110,7 @@ def check_merges(cls, values):
110110
# Convert dict format to list of Merge objects
111111
merge_list = []
112112
for merge_new_id, merge_group in merges.items():
113-
merge_list.append({"merge_unit_group": list(merge_group), "merge_new_unit_ids": merge_new_id})
113+
merge_list.append({"merge_unit_group": list(merge_group), "merge_new_unit_id": merge_new_id})
114114
merges = merge_list
115115

116116
# Make a copy of the list
@@ -138,16 +138,16 @@ def check_merges(cls, values):
138138
raise ValueError("Merge unit groups must have at least 2 elements")
139139

140140
# Check new unit id not already used
141-
if merge.merge_new_unit_ids is not None:
142-
if merge.merge_new_unit_ids in unit_ids:
143-
raise ValueError(f"New unit ID {merge.merge_new_unit_ids} is already in the unit list")
141+
if merge.merge_new_unit_id is not None:
142+
if merge.merge_new_unit_id in unit_ids:
143+
raise ValueError(f"New unit ID {merge.merge_new_unit_id} is already in the unit list")
144144

145145
values["merges"] = merges
146146
return values
147147

148148
@classmethod
149149
def check_splits(cls, values):
150-
values = dict(values)
150+
151151
unit_ids = list(values["unit_ids"])
152152
splits = values.get("splits")
153153
if splits is None:
@@ -230,7 +230,6 @@ def check_splits(cls, values):
230230

231231
@classmethod
232232
def check_removed(cls, values):
233-
values = dict(values)
234233
unit_ids = list(values["unit_ids"])
235234
removed = values.get("removed")
236235
if removed is None:
@@ -246,8 +245,6 @@ def check_removed(cls, values):
246245
@classmethod
247246
def convert_old_format(cls, values):
248247
format_version = values.get("format_version", "0")
249-
if format_version != "2":
250-
values = dict(values)
251248
if format_version == "0":
252249
print("Conversion from format version v0 (sortingview) to v2")
253250
if "mergeGroups" not in values.keys():
@@ -298,6 +295,7 @@ def convert_old_format(cls, values):
298295
@model_validator(mode="before")
299296
def validate_fields(cls, values):
300297
values = dict(values)
298+
values["label_definitions"] = values.get("label_definitions", {})
301299
values = cls.convert_old_format(values)
302300
values = cls.check_manual_labels(values)
303301
values = cls.check_merges(values)

src/spikeinterface/curation/tests/test_curation_model.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,22 +91,22 @@ def test_merge_units():
9191
"format_version": "2",
9292
"unit_ids": [1, 2, 3, 4],
9393
"merges": [
94-
{"merge_unit_group": [1, 2], "merge_new_unit_ids": 5},
95-
{"merge_unit_group": [3, 4], "merge_new_unit_ids": 6},
94+
{"merge_unit_group": [1, 2], "merge_new_unit_id": 5},
95+
{"merge_unit_group": [3, 4], "merge_new_unit_id": 6},
9696
],
9797
}
9898

9999
model = CurationModel(**valid_merge)
100100
assert len(model.merges) == 2
101-
assert model.merges[0].merge_new_unit_ids == 5
102-
assert model.merges[1].merge_new_unit_ids == 6
101+
assert model.merges[0].merge_new_unit_id == 5
102+
assert model.merges[1].merge_new_unit_id == 6
103103

104104
# Test dictionary format
105105
valid_merge_dict = {"format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": {5: [1, 2], 6: [3, 4]}}
106106

107107
model = CurationModel(**valid_merge_dict)
108108
assert len(model.merges) == 2
109-
merge_new_ids = {merge.merge_new_unit_ids for merge in model.merges}
109+
merge_new_ids = {merge.merge_new_unit_id for merge in model.merges}
110110
assert merge_new_ids == {5, 6}
111111

112112
# Test list format
@@ -122,7 +122,7 @@ def test_merge_units():
122122
invalid_merge_group = {
123123
"format_version": "2",
124124
"unit_ids": [1, 2, 3],
125-
"merges": [{"merge_unit_group": [1], "merge_new_unit_ids": 4}],
125+
"merges": [{"merge_unit_group": [1], "merge_new_unit_id": 4}],
126126
}
127127
with pytest.raises(ValidationError):
128128
CurationModel(**invalid_merge_group)
@@ -132,8 +132,8 @@ def test_merge_units():
132132
"format_version": "2",
133133
"unit_ids": [1, 2, 3],
134134
"merges": [
135-
{"merge_unit_group": [1, 2], "merge_new_unit_ids": 4},
136-
{"merge_unit_group": [2, 3], "merge_new_unit_ids": 5},
135+
{"merge_unit_group": [1, 2], "merge_new_unit_id": 4},
136+
{"merge_unit_group": [2, 3], "merge_new_unit_id": 5},
137137
],
138138
}
139139
with pytest.raises(ValidationError):
@@ -231,7 +231,7 @@ def test_removed_units():
231231
invalid_merge_remove = {
232232
"format_version": "2",
233233
"unit_ids": [1, 2, 3],
234-
"merges": [{"merge_unit_group": [1, 2], "merge_new_unit_ids": 4}],
234+
"merges": [{"merge_unit_group": [1, 2], "merge_new_unit_id": 4}],
235235
"removed": [1], # Unit is both merged and removed
236236
}
237237
with pytest.raises(ValidationError):
@@ -248,7 +248,7 @@ def test_complete_model():
248248
"tags": LabelDefinition(name="tags", label_options=["burst", "slow"], exclusive=False),
249249
},
250250
"manual_labels": [{"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst"]}}],
251-
"merges": [{"merge_unit_group": [2, 3], "merge_new_unit_ids": 6}],
251+
"merges": [{"merge_unit_group": [2, 3], "merge_new_unit_id": 6}],
252252
"splits": [
253253
{"unit_id": 4, "split_mode": "indices", "split_indices": [[0, 1], [2, 3]], "split_new_unit_ids": [7, 8]}
254254
],

0 commit comments

Comments
 (0)