@@ -16,24 +16,31 @@ class ManualLabel(BaseModel):
1616
1717
1818class Merge (BaseModel ):
19- merge_unit_group : List [Union [int , str ]] = Field (..., description = "List of groups of units to be merged" )
20- merge_new_unit_id : Optional [Union [int , str ]] = Field (default = None , description = "New unit IDs for the merge group" )
19+ unit_ids : List [Union [int , str ]] = Field (..., description = "List of unit ids to be merged" )
20+ new_unit_id : Optional [Union [int , str ]] = Field (default = None , description = "New unit ID for the merge group" )
2121
2222
2323class Split (BaseModel ):
2424 unit_id : Union [int , str ] = Field (..., description = "ID of the unit" )
25- split_mode : Literal ["indices" , "labels" ] = Field (
25+ mode : Literal ["indices" , "labels" ] = Field (
2626 default = "indices" ,
2727 description = (
2828 "Mode of the split. The split can be defined by indices or labels. "
2929 "If indices, the split is defined by the a list of lists of indices of spikes within spikes "
30- "belonging to the unit (`split_indices `). "
31- "If labels, the split is defined by a list of labels for each spike (`split_labels `). "
30+ "belonging to the unit (`indices `). "
31+ "If labels, the split is defined by a list of labels for each spike (`labels `). "
3232 ),
3333 )
34- split_indices : Optional [Union [List [List [int ]]]] = Field (default = None , description = "List of indices for the split" )
35- split_labels : Optional [List [int ]] = Field (default = None , description = "List of labels for the split" )
36- split_new_unit_ids : Optional [List [Union [int , str ]]] = Field (
34+ indices : Optional [Union [List [int ], List [List [int ]]]] = Field (
35+ default = None ,
36+ description = (
37+ "List of indices for the split. If a list of indices, the unit is split in 2 (provided indices/others). "
38+ "If a list of lists, the unit is split in multiple groups (one for each list of indices), plus an optional "
39+ "extra if the spike train has more spikes than the total number of indices in the lists."
40+ ),
41+ )
42+ labels : Optional [List [int ]] = Field (default = None , description = "List of labels for the split" )
43+ new_unit_ids : Optional [List [Union [int , str ]]] = Field (
3744 default = None , description = "List of new unit IDs for each split"
3845 )
3946
@@ -129,25 +136,36 @@ def check_merges(cls, values):
129136 # Validate merges
130137 for merge in merges :
131138 # Check unit ids exist
132- for unit_id in merge .merge_unit_group :
139+ for unit_id in merge .unit_ids :
133140 if unit_id not in unit_ids :
134141 raise ValueError (f"Merge unit group unit_id { unit_id } is not in the unit list" )
135142
136143 # Check minimum group size
137- if len (merge .merge_unit_group ) < 2 :
144+ if len (merge .unit_ids ) < 2 :
138145 raise ValueError ("Merge unit groups must have at least 2 elements" )
139146
140147 # Check new unit id not already used
141- if merge .merge_new_unit_id is not None :
142- if merge .merge_new_unit_id in unit_ids :
143- raise ValueError (f"New unit ID { merge .merge_new_unit_id } is already in the unit list" )
148+ if merge .new_unit_id is not None :
149+ if merge .new_unit_id in unit_ids :
150+ raise ValueError (f"New unit ID { merge .new_unit_id } is already in the unit list" )
144151
145152 values ["merges" ] = merges
146153 return values
147154
148155 @classmethod
149156 def check_splits (cls , values ):
150-
157+ """
158+ Checks and validates the splits in the curation model.
159+ If `splits` is a dictionary with unit_id as key and split data as values,
160+ it converts it to a list of Split objects.
161+ Each Split object is then validated:
162+ - Checks if the unit_id exists in the unit_ids list.
163+ - Validates the mode (indices or labels).
164+ - If mode is indices, checks that indices are defined and not empty, and that there are no duplicate indices.
165+ - If mode is labels, checks that labels are defined and not empty.
166+ - Validates new unit IDs if provided, ensuring they are not already in the unit_ids list and match the
167+ number of splits.
168+ """
151169 unit_ids = list (values ["unit_ids" ])
152170 splits = values .get ("splits" )
153171 if splits is None :
@@ -162,12 +180,12 @@ def check_splits(cls, values):
162180 split_list .append (
163181 {
164182 "unit_id" : unit_id ,
165- "split_mode " : "indices" ,
166- "split_indices " : [list (indices ) for indices in split_data ],
183+ "mode " : "indices" ,
184+ "indices " : [list (indices ) for indices in split_data ],
167185 }
168186 )
169187 else :
170- split_list .append ({"unit_id" : unit_id , "split_mode " : "labels" , "split_labels " : list (split_data )})
188+ split_list .append ({"unit_id" : unit_id , "mode " : "labels" , "labels " : list (split_data )})
171189 splits = split_list
172190
173191 # Make a copy of the list
@@ -177,12 +195,12 @@ def check_splits(cls, values):
177195 for i , split in enumerate (splits ):
178196 if isinstance (split , dict ):
179197 split = dict (split )
180- if "split_indices " in split :
181- split ["split_indices " ] = [list (indices ) for indices in split ["split_indices " ]]
182- if "split_labels " in split :
183- split ["split_labels " ] = list (split ["split_labels " ])
184- if "split_new_unit_ids " in split :
185- split ["split_new_unit_ids " ] = list (split ["split_new_unit_ids " ])
198+ if "indices " in split :
199+ split ["indices " ] = [list (indices ) for indices in split ["indices " ]]
200+ if "labels " in split :
201+ split ["labels " ] = list (split ["labels " ])
202+ if "new_unit_ids " in split :
203+ split ["new_unit_ids " ] = list (split ["new_unit_ids " ])
186204 splits [i ] = Split (** split )
187205
188206 # Validate splits
@@ -192,36 +210,36 @@ def check_splits(cls, values):
192210 raise ValueError (f"Split unit_id { split .unit_id } is not in the unit list" )
193211
194212 # Validate based on mode
195- if split .split_mode == "indices" :
196- if split .split_indices is None :
197- raise ValueError (f"Split unit { split .unit_id } has no split_indices defined" )
198- if len (split .split_indices ) < 1 :
199- raise ValueError (f"Split unit { split .unit_id } has empty split_indices " )
213+ if split .mode == "indices" :
214+ if split .indices is None :
215+ raise ValueError (f"Split unit { split .unit_id } has no indices defined" )
216+ if len (split .indices ) < 1 :
217+ raise ValueError (f"Split unit { split .unit_id } has empty indices " )
200218 # Check no duplicate indices
201- all_indices = list (chain .from_iterable (split .split_indices ))
219+ all_indices = list (chain .from_iterable (split .indices ))
202220 if len (all_indices ) != len (set (all_indices )):
203221 raise ValueError (f"Split unit { split .unit_id } has duplicate indices" )
204222
205- elif split .split_mode == "labels" :
206- if split .split_labels is None :
207- raise ValueError (f"Split unit { split .unit_id } has no split_labels defined" )
208- if len (split .split_labels ) == 0 :
209- raise ValueError (f"Split unit { split .unit_id } has empty split_labels " )
223+ elif split .mode == "labels" :
224+ if split .labels is None :
225+ raise ValueError (f"Split unit { split .unit_id } has no labels defined" )
226+ if len (split .labels ) == 0 :
227+ raise ValueError (f"Split unit { split .unit_id } has empty labels " )
210228
211229 # Validate new unit IDs
212- if split .split_new_unit_ids is not None :
213- if split .split_mode == "indices" :
214- if len (split .split_new_unit_ids ) != len (split .split_indices ):
230+ if split .new_unit_ids is not None :
231+ if split .mode == "indices" :
232+ if len (split .new_unit_ids ) != len (split .indices ):
215233 raise ValueError (
216234 f"Number of new unit IDs does not match number of splits for unit { split .unit_id } "
217235 )
218- elif split .split_mode == "labels" :
219- if len (split .split_new_unit_ids ) != len (set (split .split_labels )):
236+ elif split .mode == "labels" :
237+ if len (split .new_unit_ids ) != len (set (split .labels )):
220238 raise ValueError (
221239 f"Number of new unit IDs does not match number of unique labels for unit { split .unit_id } "
222240 )
223241
224- for new_id in split .split_new_unit_ids :
242+ for new_id in split .new_unit_ids :
225243 if new_id in unit_ids :
226244 raise ValueError (f"New unit ID { new_id } is already in the unit list" )
227245
@@ -312,7 +330,7 @@ def validate_curation_dict(cls, values):
312330
313331 labeled_unit_set = set ([lbl .unit_id for lbl in values .manual_labels ]) if values .manual_labels else set ()
314332 merged_units_set = (
315- set (chain .from_iterable (merge .merge_unit_group for merge in values .merges )) if values .merges else set ()
333+ set (chain .from_iterable (merge .unit_ids for merge in values .merges )) if values .merges else set ()
316334 )
317335 split_units_set = set (split .unit_id for split in values .splits ) if values .splits else set ()
318336 removed_set = set (values .removed ) if values .removed else set ()
@@ -329,7 +347,7 @@ def validate_curation_dict(cls, values):
329347 raise ValueError ("Curation format: some removed units are not in the unit list" )
330348
331349 # Check for units being merged multiple times
332- all_merging_groups = [set (merge .merge_unit_group ) for merge in values .merges ] if values .merges else []
350+ all_merging_groups = [set (merge .unit_ids ) for merge in values .merges ] if values .merges else []
333351 for gp_1 , gp_2 in combinations (all_merging_groups , 2 ):
334352 if len (gp_1 .intersection (gp_2 )) != 0 :
335353 raise ValueError ("Curation format: some units belong to multiple merge groups" )
0 commit comments