@@ -83,7 +83,6 @@ def __init__(
8383 self .signal_handler = SignalHandler (self , parent = parent )
8484
8585 self .with_traces = with_traces
86- self .main_settings = _default_main_settings .copy ()
8786 self .save_on_compute = save_on_compute
8887 self .verbose = verbose
8988
@@ -95,6 +94,8 @@ def __init__(
9594
9695 self .set_analyzer_info (analyzer )
9796 self .units_table = make_units_table_from_analyzer (self .analyzer , extra_properties = extra_unit_properties )
97+
98+ self .set_curation_info (curation , curation_data , label_definitions , curation_callback , curation_callback_kwargs )
9899
99100 # parse events
100101 self .events = None
@@ -114,81 +115,7 @@ def __init__(
114115 # spikeinterface handle colors in matplotlib style tuple values in range (0,1)
115116 self .refresh_colors ()
116117
117- self .curation = curation
118- self .curation_callback = curation_callback
119- self .curation_callback_kwargs = curation_callback_kwargs
120-
121- self ._potential_merges = None
122- self .curation = curation
123- # TODO: Reload the dictionary if it already exists
124- if self .curation :
125- # rules:
126- # * if user sends curation_data, then it is used
127- # * otherwise, if curation_data already exists in folder it is used
128- # * otherwise create an empty one
129-
130- if curation_data is not None :
131- # validate the curation data
132- format_version = curation_data .get ("format_version" , None )
133- # assume version 2 if not present
134- if format_version is None :
135- raise ValueError ("Curation data format version is missing and is required in the curation data." )
136- try :
137- validate_curation_dict (curation_data )
138- except Exception as e :
139- raise ValueError (f"Invalid curation data.\n Error: { e } " )
140-
141- if curation_data .get ("merges" ) is None :
142- curation_data ["merges" ] = []
143- else :
144- # here we reset the merges for better formatting (str)
145- existing_merges = curation_data ["merges" ]
146- new_merges = []
147- for m in existing_merges :
148- if "unit_ids" not in m :
149- continue
150- if len (m ["unit_ids" ]) < 2 :
151- continue
152- new_merges = add_merge (new_merges , m ["unit_ids" ])
153- curation_data ["merges" ] = new_merges
154- if curation_data .get ("splits" ) is None :
155- curation_data ["splits" ] = []
156- if curation_data .get ("removed" ) is None :
157- curation_data ["removed" ] = []
158-
159- elif self .analyzer .format == "binary_folder" :
160- json_file = self .analyzer .folder / "spikeinterface_gui" / "curation_data.json"
161- if json_file .exists ():
162- with open (json_file , "r" ) as f :
163- curation_data = json .load (f )
164-
165- elif self .analyzer .format == "zarr" :
166- import zarr
167- zarr_root = zarr .open (self .analyzer .folder , mode = 'r' )
168- if "spikeinterface_gui" in zarr_root .keys () and "curation_data" in zarr_root ["spikeinterface_gui" ].attrs .keys ():
169- curation_data = zarr_root ["spikeinterface_gui" ].attrs ["curation_data" ]
170-
171- if curation_data is None :
172- curation_data = deepcopy (empty_curation_data )
173- curation_data ["label_definitions" ] = default_label_definitions .copy ()
174-
175- if curation_data .get ("discard_spikes" ) is None :
176- curation_data ["discard_spikes" ] = []
177-
178- self .curation_data = curation_data
179-
180- if "label_definitions" not in self .curation_data :
181- if label_definitions is not None :
182- self .curation_data ["label_definitions" ] = label_definitions
183-
184- self .has_default_quality_labels = False
185- if "quality" in self .curation_data ["label_definitions" ]:
186- curation_dict_quality_labels = self .curation_data ["label_definitions" ]["quality" ]["label_options" ]
187- default_quality_labels = default_label_definitions ["quality" ]["label_options" ]
188- if set (curation_dict_quality_labels ) == set (default_quality_labels ):
189- if self .verbose :
190- print ('Curation quality labels are the default ones' )
191- self .has_default_quality_labels = True
118+
192119
193120 def check_is_view_possible (self , view_name ):
194121 from .viewlist import get_all_possible_views
@@ -473,30 +400,83 @@ def set_analyzer_info(self, analyzer):
473400 self .update_time_info ()
474401
475402
476- def check_is_view_possible (self , view_name ):
477- from .viewlist import get_all_possible_views
478- possible_class_views = get_all_possible_views ()
479- view_class = possible_class_views [view_name ]
480- if view_class ._depend_on is not None :
481- depencies_ok = all (self .has_extension (k ) for k in view_class ._depend_on )
482- if not depencies_ok :
483- if self .verbose :
484- print (view_name , 'does not have all dependencies' , view_class ._depend_on )
485- return False
486- return True
403+ def set_curation_info (self , curation , curation_data , label_definitions , curation_callback , curation_callback_kwargs ):
487404
488- def declare_a_view (self , new_view ):
489- assert new_view not in self .views , 'view already declared {}' .format (self )
490- self .views .append (new_view )
491- self .signal_handler .connect_view (new_view )
492-
493- @property
494- def channel_ids (self ):
495- return self .analyzer .channel_ids
405+ self .curation = curation
406+ self .curation_callback = curation_callback
407+ self .curation_callback_kwargs = curation_callback_kwargs
496408
497- @property
498- def unit_ids (self ):
499- return self .analyzer .unit_ids
409+ self ._potential_merges = None
410+ self .curation = curation
411+ # TODO: Reload the dictionary if it already exists
412+ if self .curation :
413+ # rules:
414+ # * if user sends curation_data, then it is used
415+ # * otherwise, if curation_data already exists in folder it is used
416+ # * otherwise create an empty one
417+
418+ if curation_data is not None :
419+ # validate the curation data
420+ format_version = curation_data .get ("format_version" , None )
421+ # assume version 2 if not present
422+ if format_version is None :
423+ raise ValueError ("Curation data format version is missing and is required in the curation data." )
424+ try :
425+ validate_curation_dict (curation_data )
426+ except Exception as e :
427+ raise ValueError (f"Invalid curation data.\n Error: { e } " )
428+
429+ if curation_data .get ("merges" ) is None :
430+ curation_data ["merges" ] = []
431+ else :
432+ # here we reset the merges for better formatting (str)
433+ existing_merges = curation_data ["merges" ]
434+ new_merges = []
435+ for m in existing_merges :
436+ if "unit_ids" not in m :
437+ continue
438+ if len (m ["unit_ids" ]) < 2 :
439+ continue
440+ new_merges = add_merge (new_merges , m ["unit_ids" ])
441+ curation_data ["merges" ] = new_merges
442+ if curation_data .get ("splits" ) is None :
443+ curation_data ["splits" ] = []
444+ if curation_data .get ("removed" ) is None :
445+ curation_data ["removed" ] = []
446+
447+ elif self .analyzer .format == "binary_folder" :
448+ json_file = self .analyzer .folder / "spikeinterface_gui" / "curation_data.json"
449+ if json_file .exists ():
450+ with open (json_file , "r" ) as f :
451+ curation_data = json .load (f )
452+
453+ elif self .analyzer .format == "zarr" :
454+ import zarr
455+ zarr_root = zarr .open (self .analyzer .folder , mode = 'r' )
456+ if "spikeinterface_gui" in zarr_root .keys () and "curation_data" in zarr_root ["spikeinterface_gui" ].attrs .keys ():
457+ curation_data = zarr_root ["spikeinterface_gui" ].attrs ["curation_data" ]
458+
459+ if curation_data is None :
460+ curation_data = deepcopy (empty_curation_data )
461+ curation_data ["label_definitions" ] = default_label_definitions .copy ()
462+
463+ if curation_data .get ("discard_spikes" ) is None :
464+ curation_data ["discard_spikes" ] = []
465+
466+ self .curation_data = curation_data
467+
468+ if "label_definitions" not in self .curation_data :
469+ if label_definitions is not None :
470+ self .curation_data ["label_definitions" ] = label_definitions
471+
472+ self .has_default_quality_labels = False
473+ if "quality" in self .curation_data ["label_definitions" ]:
474+ curation_dict_quality_labels = self .curation_data ["label_definitions" ]["quality" ]["label_options" ]
475+ default_quality_labels = default_label_definitions ["quality" ]["label_options" ]
476+ if set (curation_dict_quality_labels ) == set (default_quality_labels ):
477+ if self .verbose :
478+ print ('Curation quality labels are the default ones' )
479+ self .has_default_quality_labels = True
500480
501481 def get_time (self ):
502482 """
@@ -953,7 +933,6 @@ def construct_final_curation(self, with_explicit_new_unit_ids=False):
953933 d ["format_version" ] = "2"
954934 d ["unit_ids" ] = self .unit_ids .tolist ()
955935 d .update (self .curation_data .copy ())
956-
957936 if with_explicit_new_unit_ids :
958937 split_new_id_strategy = self .main_settings .get ('split_new_id_strategy' )
959938 merge_new_id_strategy = self .main_settings .get ('merge_new_id_strategy' )
@@ -972,7 +951,7 @@ def apply_curation(self):
972951 curated_analyzer = apply_curation (self .analyzer , curation )
973952
974953 self .applied_curations .append (curation )
975- self .remove_curation ()
954+ self .remove_curation (curated_analyzer )
976955
977956 self .set_analyzer_info (curated_analyzer )
978957
@@ -984,10 +963,21 @@ def apply_curation(self):
984963 for view in self .views :
985964 view .reinitialize ()
986965
987- def remove_curation (self ):
988- label_definitioins = self .curation_data .get ("label_definitions" , None )
966+ def remove_curation (self , curated_analyzer ):
967+ """Removes curation from the controller, retaining quality labels."""
968+
989969 curation_data = deepcopy (empty_curation_data )
970+ # retain label definitions and 'quality' label
971+ label_definitioins = self .curation_data .get ("label_definitions" , None )
990972 curation_data ["label_definitions" ] = label_definitioins
973+
974+ if (quality_labels := curated_analyzer .get_sorting_property ('quality' )) is not None :
975+ manual_labels = []
976+ for unit_id , quality_label in zip (curated_analyzer .unit_ids , quality_labels ):
977+ manual_labels .append ({'unit_id' : unit_id , 'labels' : {'quality' : [quality_label ]}})
978+
979+ curation_data ['manual_labels' ] = manual_labels
980+
991981 self .curation_data = curation_data
992982
993983 def set_curation_data (self , curation_data ):
0 commit comments