@@ -125,60 +125,14 @@ def apply_curation_labels(
125125 # Please note that manual_labels is done on the unit_ids before the merge!!!
126126 manual_labels = curation_label_to_vectors (curation_model )
127127
128- # apply on non merged / split
129- merge_new_unit_ids = [m .new_unit_id for m in curation_model .merges ]
130- split_new_unit_ids = [m .new_unit_ids for m in curation_model .splits ]
131- split_new_unit_ids = list (chain (* split_new_unit_ids ))
132-
133- merged_split_units = merge_new_unit_ids + split_new_unit_ids
134128 for key , values in manual_labels .items ():
135129 all_values = np .zeros (sorting .unit_ids .size , dtype = values .dtype )
136130 for unit_ind , unit_id in enumerate (sorting .unit_ids ):
137- if unit_id not in merged_split_units :
138- ind = list (curation_model .unit_ids ).index (unit_id )
139- all_values [unit_ind ] = values [ind ]
131+ # if unit_id not in merged_split_units:
132+ ind = list (curation_model .unit_ids ).index (unit_id )
133+ all_values [unit_ind ] = values [ind ]
140134 sorting .set_property (key , all_values )
141135
142- for new_unit_id , merge in zip (merge_new_unit_ids , curation_model .merges ):
143- old_group_ids = merge .unit_ids
144- for label_key , label_def in curation_model .label_definitions .items ():
145- if label_def .exclusive :
146- group_values = []
147- for unit_id in old_group_ids :
148- ind = list (curation_model .unit_ids ).index (unit_id )
149- value = manual_labels [label_key ][ind ]
150- if value != "" :
151- group_values .append (value )
152- if len (set (group_values )) == 1 :
153- # all group has the same label or empty
154- sorting .set_property (key , values = group_values [:1 ], ids = [new_unit_id ])
155- else :
156- for key in label_def .label_options :
157- group_values = []
158- for unit_id in old_group_ids :
159- ind = list (curation_model .unit_ids ).index (unit_id )
160- value = manual_labels [key ][ind ]
161- group_values .append (value )
162- new_value = np .any (group_values )
163- sorting .set_property (key , values = [new_value ], ids = [new_unit_id ])
164-
165- # splits
166- for split in curation_model .splits :
167- # propagate property of splut unit to new units
168- old_unit = split .unit_id
169- new_unit_ids = split .new_unit_ids
170- for label_key , label_def in curation_model .label_definitions .items ():
171- if label_def .exclusive :
172- ind = list (curation_model .unit_ids ).index (old_unit )
173- value = manual_labels [label_key ][ind ]
174- if value != "" :
175- sorting .set_property (label_key , values = [value ] * len (new_unit_ids ), ids = new_unit_ids )
176- else :
177- for key in label_def .label_options :
178- ind = list (curation_model .unit_ids ).index (old_unit )
179- value = manual_labels [key ][ind ]
180- sorting .set_property (key , values = [value ] * len (new_unit_ids ), ids = new_unit_ids )
181-
182136
183137def apply_curation (
184138 sorting_or_analyzer : BaseSorting | SortingAnalyzer ,
@@ -194,10 +148,11 @@ def apply_curation(
194148 Apply curation dict to a Sorting or a SortingAnalyzer.
195149
196150 Steps are done in this order:
197- 1. Apply removal using curation_dict["removed"]
198- 2. Apply merges using curation_dict["merges"]
199- 3. Apply splits using curation_dict["splits"]
200- 4. Set labels using curation_dict["manual_labels"]
151+
152+ 1. Apply labels using curation_dict["manual_labels"]
153+ 2. Apply removal using curation_dict["removed"]
154+ 3. Apply merges using curation_dict["merges"]
155+ 4. Apply splits using curation_dict["splits"]
201156
202157 A new Sorting or SortingAnalyzer (in memory) is returned.
203158 The user (an adult) has the responsability to save it somewhere (or not).
@@ -243,33 +198,36 @@ def apply_curation(
243198 if isinstance (curation_dict_or_model , dict ):
244199 curation_model = CurationModel (** curation_dict_or_model )
245200 else :
246- curation_model = curation_dict_or_model
201+ curation_model = curation_dict_or_model . model_copy ( deep = True )
247202
248203 if not np .array_equal (np .asarray (curation_model .unit_ids ), sorting_or_analyzer .unit_ids ):
249204 raise ValueError ("unit_ids from the curation_dict do not match the one from Sorting or SortingAnalyzer" )
250205
251- # 1. Remove units
206+ # 1. Apply labels
207+ apply_curation_labels (sorting_or_analyzer , curation_model )
208+
209+ # 2. Remove units
252210 if len (curation_model .removed ) > 0 :
253211 curated_sorting_or_analyzer = sorting_or_analyzer .remove_units (curation_model .removed )
254212 else :
255213 curated_sorting_or_analyzer = sorting_or_analyzer
256214
257- # 2 . Merge units
215+ # 3 . Merge units
258216 if len (curation_model .merges ) > 0 :
259217 merge_unit_groups = [m .unit_ids for m in curation_model .merges ]
260218 merge_new_unit_ids = [m .new_unit_id for m in curation_model .merges if m .new_unit_id is not None ]
261219 if len (merge_new_unit_ids ) == 0 :
262220 merge_new_unit_ids = None
263221 if isinstance (sorting_or_analyzer , BaseSorting ):
264- curated_sorting_or_analyzer , _ , new_unit_ids = apply_merges_to_sorting (
222+ curated_sorting_or_analyzer , _ , _ = apply_merges_to_sorting (
265223 curated_sorting_or_analyzer ,
266224 merge_unit_groups = merge_unit_groups ,
267225 censor_ms = censor_ms ,
268226 new_id_strategy = new_id_strategy ,
269227 return_extra = True ,
270228 )
271229 else :
272- curated_sorting_or_analyzer , new_unit_ids = curated_sorting_or_analyzer .merge_units (
230+ curated_sorting_or_analyzer , _ = curated_sorting_or_analyzer .merge_units (
273231 merge_unit_groups = merge_unit_groups ,
274232 censor_ms = censor_ms ,
275233 merging_mode = merging_mode ,
@@ -280,10 +238,8 @@ def apply_curation(
280238 verbose = verbose ,
281239 ** job_kwargs ,
282240 )
283- for i , merge_unit_id in enumerate (new_unit_ids ):
284- curation_model .merges [i ].new_unit_id = merge_unit_id
285241
286- # 3 . Split units
242+ # 4 . Split units
287243 if len (curation_model .splits ) > 0 :
288244 split_units = {}
289245 for split in curation_model .splits :
@@ -297,26 +253,21 @@ def apply_curation(
297253 if len (split_new_unit_ids ) == 0 :
298254 split_new_unit_ids = None
299255 if isinstance (sorting_or_analyzer , BaseSorting ):
300- curated_sorting_or_analyzer , new_unit_ids = apply_splits_to_sorting (
256+ curated_sorting_or_analyzer , _ = apply_splits_to_sorting (
301257 curated_sorting_or_analyzer ,
302258 split_units ,
303259 new_unit_ids = split_new_unit_ids ,
304260 new_id_strategy = new_id_strategy ,
305261 return_extra = True ,
306262 )
307263 else :
308- curated_sorting_or_analyzer , new_unit_ids = curated_sorting_or_analyzer .split_units (
264+ curated_sorting_or_analyzer , _ = curated_sorting_or_analyzer .split_units (
309265 split_units ,
310266 new_id_strategy = new_id_strategy ,
311267 return_new_unit_ids = True ,
312268 new_unit_ids = split_new_unit_ids ,
313269 format = "memory" ,
314270 verbose = verbose ,
315271 )
316- for i , split_unit_ids in enumerate (new_unit_ids ):
317- curation_model .splits [i ].new_unit_ids = split_unit_ids
318-
319- # 4. Apply labels
320- apply_curation_labels (curated_sorting_or_analyzer , curation_model )
321272
322273 return curated_sorting_or_analyzer
0 commit comments