33from pydantic import ValidationError
44import numpy as np
55
6- from spikeinterface .curation .curation_model import CurationModel , SequentialCuration , LabelDefinition
6+ from spikeinterface .curation .curation_model import Curation , SequentialCuration , LabelDefinition
77
88
99# Test data for format version
1010def test_format_version ():
1111 # Valid format version
12- CurationModel (format_version = "1" , unit_ids = [1 , 2 , 3 ])
12+ Curation (format_version = "1" , unit_ids = [1 , 2 , 3 ])
1313
1414 # Invalid format version
1515 with pytest .raises (ValidationError ):
16- CurationModel (format_version = "3" , unit_ids = [1 , 2 , 3 ])
16+ Curation (format_version = "3" , unit_ids = [1 , 2 , 3 ])
1717 with pytest .raises (ValidationError ):
18- CurationModel (format_version = "0.1" , unit_ids = [1 , 2 , 3 ])
18+ Curation (format_version = "0.1" , unit_ids = [1 , 2 , 3 ])
1919
2020
2121# Test data for label definitions
@@ -29,7 +29,7 @@ def test_label_definitions():
2929 },
3030 }
3131
32- model = CurationModel (** valid_label_def )
32+ model = Curation (** valid_label_def )
3333 assert "quality" in model .label_definitions
3434 assert model .label_definitions ["quality" ].name == "quality"
3535 assert model .label_definitions ["quality" ].exclusive is True
@@ -54,7 +54,7 @@ def test_manual_labels():
5454 ],
5555 }
5656
57- model = CurationModel (** valid_labels )
57+ model = Curation (** valid_labels )
5858 assert len (model .manual_labels ) == 2
5959
6060 # Test invalid unit ID
@@ -67,7 +67,7 @@ def test_manual_labels():
6767 "manual_labels" : [{"unit_id" : 4 , "labels" : {"quality" : ["good" ]}}], # Non-existent unit
6868 }
6969 with pytest .raises (ValidationError ):
70- CurationModel (** invalid_unit )
70+ Curation (** invalid_unit )
7171
7272 # Test violation of exclusive label
7373 invalid_exclusive = {
@@ -81,7 +81,7 @@ def test_manual_labels():
8181 ],
8282 }
8383 with pytest .raises (ValidationError ):
84- CurationModel (** invalid_exclusive )
84+ Curation (** invalid_exclusive )
8585
8686
8787# Test merge functionality
@@ -96,15 +96,15 @@ def test_merge_units():
9696 ],
9797 }
9898
99- model = CurationModel (** valid_merge )
99+ model = Curation (** valid_merge )
100100 assert len (model .merges ) == 2
101101 assert model .merges [0 ].new_unit_id == 5
102102 assert model .merges [1 ].new_unit_id == 6
103103
104104 # Test dictionary format
105105 valid_merge_dict = {"format_version" : "2" , "unit_ids" : [1 , 2 , 3 , 4 ], "merges" : {5 : [1 , 2 ], 6 : [3 , 4 ]}}
106106
107- model = CurationModel (** valid_merge_dict )
107+ model = Curation (** valid_merge_dict )
108108 assert len (model .merges ) == 2
109109 merge_new_ids = {merge .new_unit_id for merge in model .merges }
110110 assert merge_new_ids == {5 , 6 }
@@ -115,7 +115,7 @@ def test_merge_units():
115115 "unit_ids" : [1 , 2 , 3 , 4 ],
116116 "merges" : [[1 , 2 ], [3 , 4 ]], # Merge each pair into a new unit
117117 }
118- model = CurationModel (** valid_merge_list )
118+ model = Curation (** valid_merge_list )
119119 assert len (model .merges ) == 2
120120
121121 # Test invalid merge group (single unit)
@@ -125,7 +125,7 @@ def test_merge_units():
125125 "merges" : [{"unit_ids" : [1 ], "new_unit_id" : 4 }],
126126 }
127127 with pytest .raises (ValidationError ):
128- CurationModel (** invalid_merge_group )
128+ Curation (** invalid_merge_group )
129129
130130 # Test overlapping merge groups
131131 invalid_overlap = {
@@ -137,7 +137,7 @@ def test_merge_units():
137137 ],
138138 }
139139 with pytest .raises (ValidationError ):
140- CurationModel (** invalid_overlap )
140+ Curation (** invalid_overlap )
141141
142142
143143# Test split functionality
@@ -156,7 +156,7 @@ def test_split_units():
156156 ],
157157 }
158158
159- model = CurationModel (** valid_split_indices )
159+ model = Curation (** valid_split_indices )
160160 assert len (model .splits ) == 1
161161 assert model .splits [0 ].mode == "indices"
162162 assert len (model .splits [0 ].indices ) == 2
@@ -168,7 +168,7 @@ def test_split_units():
168168 "splits" : [{"unit_id" : 1 , "mode" : "labels" , "labels" : [0 , 0 , 1 , 1 , 0 , 2 ], "new_unit_ids" : [4 , 5 , 6 ]}],
169169 }
170170
171- model = CurationModel (** valid_split_labels )
171+ model = Curation (** valid_split_labels )
172172 assert len (model .splits ) == 1
173173 assert model .splits [0 ].mode == "labels"
174174 assert len (set (model .splits [0 ].labels )) == 3
@@ -183,7 +183,7 @@ def test_split_units():
183183 },
184184 }
185185
186- model = CurationModel (** valid_split_dict )
186+ model = Curation (** valid_split_dict )
187187 assert len (model .splits ) == 2
188188 assert all (split .mode == "indices" for split in model .splits )
189189
@@ -194,7 +194,7 @@ def test_split_units():
194194 "splits" : [{"unit_id" : 4 , "mode" : "indices" , "indices" : [[0 , 1 ], [2 , 3 ]]}], # Non-existent unit
195195 }
196196 with pytest .raises (ValidationError ):
197- CurationModel (** invalid_unit_id )
197+ Curation (** invalid_unit_id )
198198
199199 # Test invalid new unit IDs count for indices mode
200200 invalid_new_ids = {
@@ -210,20 +210,20 @@ def test_split_units():
210210 ],
211211 }
212212 with pytest .raises (ValidationError ):
213- CurationModel (** invalid_new_ids )
213+ Curation (** invalid_new_ids )
214214
215215
216216# Test removed units
217217def test_removed_units ():
218218 valid_remove = {"format_version" : "2" , "unit_ids" : [1 , 2 , 3 ], "removed" : [2 ]}
219219
220- model = CurationModel (** valid_remove )
220+ model = Curation (** valid_remove )
221221 assert len (model .removed ) == 1
222222
223223 # Test removing non-existent unit
224224 invalid_remove = {"format_version" : "2" , "unit_ids" : [1 , 2 , 3 ], "removed" : [4 ]} # Non-existent unit
225225 with pytest .raises (ValidationError ):
226- CurationModel (** invalid_remove )
226+ Curation (** invalid_remove )
227227
228228 # Test conflict between merge and remove
229229 invalid_merge_remove = {
@@ -233,7 +233,7 @@ def test_removed_units():
233233 "removed" : [1 ], # Unit is both merged and removed
234234 }
235235 with pytest .raises (ValidationError ):
236- CurationModel (** invalid_merge_remove )
236+ Curation (** invalid_merge_remove )
237237
238238
239239# Test complete model with multiple operations
@@ -251,7 +251,7 @@ def test_complete_model():
251251 "removed" : [5 ],
252252 }
253253
254- model = CurationModel (** complete_model )
254+ model = Curation (** complete_model )
255255 assert model .format_version == "2"
256256 assert len (model .unit_ids ) == 5
257257 assert len (model .label_definitions ) == 2
@@ -274,7 +274,7 @@ def test_complete_model():
274274 "removed" : [5 ],
275275 }
276276
277- model = CurationModel (** complete_model_dict )
277+ model = Curation (** complete_model_dict )
278278 assert model .format_version == "2"
279279 assert len (model .unit_ids ) == 5
280280 assert len (model .label_definitions ) == 2
0 commit comments