Skip to content

Commit 3abdda8

Browse files
authored
Rename CurationModel to Curation (#4421)
1 parent 747ee14 commit 3abdda8

7 files changed

Lines changed: 70 additions & 61 deletions

File tree

doc/api.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ such as merging and splitting units, as well as defining labels for units.
390390

391391
.. automodule:: spikeinterface.curation.curation_model
392392

393-
.. autopydantic_model:: CurationModel
393+
.. autopydantic_model:: Curation
394394
.. autopydantic_model:: Merge
395395
.. autopydantic_model:: Split
396396
.. autopydantic_model:: ManualLabel

src/spikeinterface/curation/curation_format.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from itertools import chain
55

66
from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting, apply_splits_to_sorting
7-
from spikeinterface.curation.curation_model import CurationModel, SequentialCuration
7+
from spikeinterface.curation.curation_model import Curation, SequentialCuration
88

99

1010
def validate_curation_dict(curation_dict: dict):
@@ -19,10 +19,10 @@ def validate_curation_dict(curation_dict: dict):
1919
2020
"""
2121
# this will validate the format of the curation_dict
22-
CurationModel(**curation_dict)
22+
Curation(**curation_dict)
2323

2424

25-
def curation_label_to_vectors(curation_dict_or_model: dict | CurationModel):
25+
def curation_label_to_vectors(curation_dict_or_model: dict | Curation):
2626
"""
2727
Transform the curation dict into dict of vectors.
2828
For label category with exclusive=True : a column is created and values are the unique label.
@@ -32,7 +32,7 @@ def curation_label_to_vectors(curation_dict_or_model: dict | CurationModel):
3232
3333
Parameters
3434
----------
35-
curation_dict : dict or CurationModel
35+
curation_dict : dict or Curation
3636
A curation dictionary or model
3737
3838
Returns
@@ -41,7 +41,7 @@ def curation_label_to_vectors(curation_dict_or_model: dict | CurationModel):
4141
4242
"""
4343
if isinstance(curation_dict_or_model, dict):
44-
curation_model = CurationModel(**curation_dict_or_model)
44+
curation_model = Curation(**curation_dict_or_model)
4545
else:
4646
curation_model = curation_dict_or_model
4747
unit_ids = list(curation_model.unit_ids)
@@ -71,7 +71,7 @@ def curation_label_to_vectors(curation_dict_or_model: dict | CurationModel):
7171
return labels
7272

7373

74-
def curation_label_to_dataframe(curation_dict_or_model: dict | CurationModel):
74+
def curation_label_to_dataframe(curation_dict_or_model: dict | Curation):
7575
"""
7676
Transform the curation dict into a pandas dataframe.
7777
For label category with exclusive=True : a column is created and values are the unique label.
@@ -92,17 +92,15 @@ def curation_label_to_dataframe(curation_dict_or_model: dict | CurationModel):
9292
import pandas as pd
9393

9494
if isinstance(curation_dict_or_model, dict):
95-
curation_model = CurationModel(**curation_dict_or_model)
95+
curation_model = Curation(**curation_dict_or_model)
9696
else:
9797
curation_model = curation_dict_or_model
9898

9999
labels = pd.DataFrame(curation_label_to_vectors(curation_model), index=curation_model.unit_ids)
100100
return labels
101101

102102

103-
def apply_curation_labels(
104-
sorting_or_analyzer: BaseSorting | SortingAnalyzer, curation_dict_or_model: dict | CurationModel
105-
):
103+
def apply_curation_labels(sorting_or_analyzer: BaseSorting | SortingAnalyzer, curation_dict_or_model: dict | Curation):
106104
"""
107105
Apply manual labels after merges/splits.
108106
@@ -113,7 +111,7 @@ def apply_curation_labels(
113111
* for split units, the original label is applied to all split units
114112
"""
115113
if isinstance(curation_dict_or_model, dict):
116-
curation_model = CurationModel(**curation_dict_or_model)
114+
curation_model = Curation(**curation_dict_or_model)
117115
else:
118116
curation_model = curation_dict_or_model
119117

@@ -136,7 +134,7 @@ def apply_curation_labels(
136134

137135
def apply_curation(
138136
sorting_or_analyzer: BaseSorting | SortingAnalyzer,
139-
curation_dict_or_model: dict | list | CurationModel | SequentialCuration,
137+
curation_dict_or_model: dict | list | Curation | SequentialCuration,
140138
censor_ms: float | None = None,
141139
new_id_strategy: str = "append",
142140
merging_mode: str = "soft",
@@ -162,7 +160,7 @@ def apply_curation(
162160
----------
163161
sorting_or_analyzer : Sorting | SortingAnalyzer
164162
The Sorting or SortingAnalyzer object to apply merges.
165-
curation_dict : dict | CurationModel | SequentialCuration
163+
curation_dict : dict | Curation | SequentialCuration
166164
The curation dict or model.
167165
censor_ms : float | None, default: None
168166
When applying the merges, any consecutive spikes within the `censor_ms` are removed. This can be thought of
@@ -197,10 +195,10 @@ def apply_curation(
197195
sorting_or_analyzer, (BaseSorting, SortingAnalyzer)
198196
), f"`sorting_or_analyzer` must be a Sorting or a SortingAnalyzer, not an object of type {type(sorting_or_analyzer)}"
199197
assert isinstance(
200-
curation_dict_or_model, (dict, list, CurationModel, SequentialCuration)
201-
), f"`curation_dict_or_model` must be a dict, CurationModel or a SequentialCuration not an object of type {type(curation_dict_or_model)}"
198+
curation_dict_or_model, (dict, list, Curation, SequentialCuration)
199+
), f"`curation_dict_or_model` must be a dict, Curation or a SequentialCuration not an object of type {type(curation_dict_or_model)}"
202200
if isinstance(curation_dict_or_model, dict):
203-
curation_model = CurationModel(**curation_dict_or_model)
201+
curation_model = Curation(**curation_dict_or_model)
204202
elif isinstance(curation_dict_or_model, list):
205203
curation_model = SequentialCuration(curation_steps=curation_dict_or_model)
206204
else:
@@ -298,7 +296,7 @@ def apply_curation(
298296
return curated_sorting_or_analyzer
299297

300298

301-
def load_curation(curation_path: str | Path) -> CurationModel:
299+
def load_curation(curation_path: str | Path) -> Curation:
302300
"""
303301
Loads a curation from a local json file.
304302
@@ -309,9 +307,9 @@ def load_curation(curation_path: str | Path) -> CurationModel:
309307
310308
Returns
311309
-------
312-
curation_model : CurationModel
313-
A CurationModel object
310+
curation_model : Curation
311+
A Curation object
314312
"""
315313
with open(curation_path) as f:
316314
curation_dict = json.load(f)
317-
return CurationModel(**curation_dict)
315+
return Curation(**curation_dict)

src/spikeinterface/curation/curation_model.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from pydantic import BaseModel, Field, model_validator, field_validator, field_serializer
1+
import warnings
22
from typing import Literal, List
33
from itertools import chain, combinations
4+
from pydantic import BaseModel, Field, model_validator, field_validator, field_serializer
5+
46
import numpy as np
57

68
from spikeinterface import BaseSorting
@@ -74,7 +76,7 @@ def get_full_spike_indices(self, sorting: BaseSorting):
7476
return full_spike_indices
7577

7678

77-
class CurationModel(BaseModel):
79+
class Curation(BaseModel):
7880
supported_versions: tuple[Literal["1"], Literal["2"]] = Field(
7981
default=("1", "2"), description="Supported versions of the curation format"
8082
)
@@ -477,14 +479,23 @@ def validate_curation_dict(self):
477479
return self
478480

479481

482+
def CurationModel(*args, **kwargs):
483+
warnings.warn(
484+
"`CurationModel` is deprecated and will be removed in 0.105.0. Use `Curation` instead",
485+
DeprecationWarning,
486+
stacklevel=2,
487+
)
488+
return Curation(*args, **kwargs)
489+
490+
480491
class SequentialCuration(BaseModel):
481492
"""
482493
A Pydantic model which defines a sequence of curation steps. If using sequential curations,
483494
we demand that each individual curation (except the final one) has manually defined new unit ids,
484495
and that these match the unit ids of the following curation.
485496
"""
486497

487-
curation_steps: List[CurationModel] = Field(description="List of curation steps applied sequentially")
498+
curation_steps: List[Curation] = Field(description="List of curation steps applied sequentially")
488499

489500
@model_validator(mode="after")
490501
def validate_sequential_curation(self):

src/spikeinterface/curation/sortingview_curation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
apply_curation,
1010
curation_label_to_vectors,
1111
)
12-
from .curation_model import CurationModel, Merge
12+
from .curation_model import Curation, Merge
1313

1414

1515
def get_kachery():
@@ -83,7 +83,7 @@ def apply_sortingview_curation(
8383

8484
unit_ids = sorting_or_analyzer.unit_ids
8585
curation_dict["unit_ids"] = unit_ids
86-
curation_model = CurationModel(**curation_dict)
86+
curation_model = Curation(**curation_dict)
8787

8888
if skip_merge:
8989
curation_model.merges = []

src/spikeinterface/curation/tests/test_curation_model.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@
33
from pydantic import ValidationError
44
import numpy as np
55

6-
from spikeinterface.curation.curation_model import CurationModel, SequentialCuration, LabelDefinition
6+
from spikeinterface.curation.curation_model import Curation, SequentialCuration, LabelDefinition
77

88

99
# Test data for format version
1010
def test_format_version():
1111
# Valid format version
12-
CurationModel(format_version="1", unit_ids=[1, 2, 3])
12+
Curation(format_version="1", unit_ids=[1, 2, 3])
1313

1414
# Invalid format version
1515
with pytest.raises(ValidationError):
16-
CurationModel(format_version="3", unit_ids=[1, 2, 3])
16+
Curation(format_version="3", unit_ids=[1, 2, 3])
1717
with pytest.raises(ValidationError):
18-
CurationModel(format_version="0.1", unit_ids=[1, 2, 3])
18+
Curation(format_version="0.1", unit_ids=[1, 2, 3])
1919

2020

2121
# Test data for label definitions
@@ -29,7 +29,7 @@ def test_label_definitions():
2929
},
3030
}
3131

32-
model = CurationModel(**valid_label_def)
32+
model = Curation(**valid_label_def)
3333
assert "quality" in model.label_definitions
3434
assert model.label_definitions["quality"].name == "quality"
3535
assert model.label_definitions["quality"].exclusive is True
@@ -54,7 +54,7 @@ def test_manual_labels():
5454
],
5555
}
5656

57-
model = CurationModel(**valid_labels)
57+
model = Curation(**valid_labels)
5858
assert len(model.manual_labels) == 2
5959

6060
# Test invalid unit ID
@@ -67,7 +67,7 @@ def test_manual_labels():
6767
"manual_labels": [{"unit_id": 4, "labels": {"quality": ["good"]}}], # Non-existent unit
6868
}
6969
with pytest.raises(ValidationError):
70-
CurationModel(**invalid_unit)
70+
Curation(**invalid_unit)
7171

7272
# Test violation of exclusive label
7373
invalid_exclusive = {
@@ -81,7 +81,7 @@ def test_manual_labels():
8181
],
8282
}
8383
with pytest.raises(ValidationError):
84-
CurationModel(**invalid_exclusive)
84+
Curation(**invalid_exclusive)
8585

8686

8787
# Test merge functionality
@@ -96,15 +96,15 @@ def test_merge_units():
9696
],
9797
}
9898

99-
model = CurationModel(**valid_merge)
99+
model = Curation(**valid_merge)
100100
assert len(model.merges) == 2
101101
assert model.merges[0].new_unit_id == 5
102102
assert model.merges[1].new_unit_id == 6
103103

104104
# Test dictionary format
105105
valid_merge_dict = {"format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": {5: [1, 2], 6: [3, 4]}}
106106

107-
model = CurationModel(**valid_merge_dict)
107+
model = Curation(**valid_merge_dict)
108108
assert len(model.merges) == 2
109109
merge_new_ids = {merge.new_unit_id for merge in model.merges}
110110
assert merge_new_ids == {5, 6}
@@ -115,7 +115,7 @@ def test_merge_units():
115115
"unit_ids": [1, 2, 3, 4],
116116
"merges": [[1, 2], [3, 4]], # Merge each pair into a new unit
117117
}
118-
model = CurationModel(**valid_merge_list)
118+
model = Curation(**valid_merge_list)
119119
assert len(model.merges) == 2
120120

121121
# Test invalid merge group (single unit)
@@ -125,7 +125,7 @@ def test_merge_units():
125125
"merges": [{"unit_ids": [1], "new_unit_id": 4}],
126126
}
127127
with pytest.raises(ValidationError):
128-
CurationModel(**invalid_merge_group)
128+
Curation(**invalid_merge_group)
129129

130130
# Test overlapping merge groups
131131
invalid_overlap = {
@@ -137,7 +137,7 @@ def test_merge_units():
137137
],
138138
}
139139
with pytest.raises(ValidationError):
140-
CurationModel(**invalid_overlap)
140+
Curation(**invalid_overlap)
141141

142142

143143
# Test split functionality
@@ -156,7 +156,7 @@ def test_split_units():
156156
],
157157
}
158158

159-
model = CurationModel(**valid_split_indices)
159+
model = Curation(**valid_split_indices)
160160
assert len(model.splits) == 1
161161
assert model.splits[0].mode == "indices"
162162
assert len(model.splits[0].indices) == 2
@@ -168,7 +168,7 @@ def test_split_units():
168168
"splits": [{"unit_id": 1, "mode": "labels", "labels": [0, 0, 1, 1, 0, 2], "new_unit_ids": [4, 5, 6]}],
169169
}
170170

171-
model = CurationModel(**valid_split_labels)
171+
model = Curation(**valid_split_labels)
172172
assert len(model.splits) == 1
173173
assert model.splits[0].mode == "labels"
174174
assert len(set(model.splits[0].labels)) == 3
@@ -183,7 +183,7 @@ def test_split_units():
183183
},
184184
}
185185

186-
model = CurationModel(**valid_split_dict)
186+
model = Curation(**valid_split_dict)
187187
assert len(model.splits) == 2
188188
assert all(split.mode == "indices" for split in model.splits)
189189

@@ -194,7 +194,7 @@ def test_split_units():
194194
"splits": [{"unit_id": 4, "mode": "indices", "indices": [[0, 1], [2, 3]]}], # Non-existent unit
195195
}
196196
with pytest.raises(ValidationError):
197-
CurationModel(**invalid_unit_id)
197+
Curation(**invalid_unit_id)
198198

199199
# Test invalid new unit IDs count for indices mode
200200
invalid_new_ids = {
@@ -210,20 +210,20 @@ def test_split_units():
210210
],
211211
}
212212
with pytest.raises(ValidationError):
213-
CurationModel(**invalid_new_ids)
213+
Curation(**invalid_new_ids)
214214

215215

216216
# Test removed units
217217
def test_removed_units():
218218
valid_remove = {"format_version": "2", "unit_ids": [1, 2, 3], "removed": [2]}
219219

220-
model = CurationModel(**valid_remove)
220+
model = Curation(**valid_remove)
221221
assert len(model.removed) == 1
222222

223223
# Test removing non-existent unit
224224
invalid_remove = {"format_version": "2", "unit_ids": [1, 2, 3], "removed": [4]} # Non-existent unit
225225
with pytest.raises(ValidationError):
226-
CurationModel(**invalid_remove)
226+
Curation(**invalid_remove)
227227

228228
# Test conflict between merge and remove
229229
invalid_merge_remove = {
@@ -233,7 +233,7 @@ def test_removed_units():
233233
"removed": [1], # Unit is both merged and removed
234234
}
235235
with pytest.raises(ValidationError):
236-
CurationModel(**invalid_merge_remove)
236+
Curation(**invalid_merge_remove)
237237

238238

239239
# Test complete model with multiple operations
@@ -251,7 +251,7 @@ def test_complete_model():
251251
"removed": [5],
252252
}
253253

254-
model = CurationModel(**complete_model)
254+
model = Curation(**complete_model)
255255
assert model.format_version == "2"
256256
assert len(model.unit_ids) == 5
257257
assert len(model.label_definitions) == 2
@@ -274,7 +274,7 @@ def test_complete_model():
274274
"removed": [5],
275275
}
276276

277-
model = CurationModel(**complete_model_dict)
277+
model = Curation(**complete_model_dict)
278278
assert model.format_version == "2"
279279
assert len(model.unit_ids) == 5
280280
assert len(model.label_definitions) == 2

0 commit comments

Comments
 (0)